Unverified Commit a462fc92 authored by Younes Belkada, committed by GitHub

Bloom Optimize operations (#17866)



* fix tolerance for a bloom slow test

* enhance alibi padding

- get rid of for loops
- deals better with padded batched input
- avoid useless cpu/gpu communication when creating alibi
Co-authored-by: justheuristic <justheuristic@gmail.com>
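For illustration, here is a minimal sketch of a loop-free ALiBi construction in that spirit (the helper name and shapes are illustrative, not necessarily the PR's exact code): the per-head slopes are built once on the mask's device and broadcast against positions derived from the attention mask's cumulative sum, so padded tokens get no offset and no Python loop or host/device round-trip is needed.

```python
import math
import torch

def build_alibi_tensor_sketch(attention_mask: torch.Tensor, n_head: int, dtype: torch.dtype) -> torch.Tensor:
    """Sketch of a loop-free ALiBi bias for padded batches.

    attention_mask: (batch_size, seq_length), 1 for real tokens, 0 for padding.
    Returns a bias of shape (batch_size * n_head, 1, seq_length).
    """
    # Per-head slopes: geometric sequence derived from the closest power of two <= n_head.
    closest_power_of_2 = 2 ** math.floor(math.log2(n_head))
    base = 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3)))
    slopes = torch.pow(base, torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.float32))
    if closest_power_of_2 != n_head:
        extra_base = 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3)))
        num_remaining = n_head - closest_power_of_2
        extra_slopes = torch.pow(
            extra_base, torch.arange(1, 1 + 2 * num_remaining, 2, device=attention_mask.device, dtype=torch.float32)
        )
        slopes = torch.cat([slopes, extra_slopes], dim=0)

    # Positions counted from the first non-padded token: cumulative sum of the mask,
    # re-masked so padded positions contribute 0. Shape: (batch_size, 1, seq_length).
    arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]

    # Broadcast slopes over batch and positions; everything stays on the mask's device.
    alibi = slopes[None, :, None] * arange_tensor  # (batch_size, n_head, seq_length)
    return alibi.reshape(attention_mask.shape[0] * n_head, 1, -1).to(dtype)
```

With left padding, positions restart at 0 on each row's first real token, which is what lets a single batched expression replace the per-example loop.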

* optimize attention mask

* fix scaled softmax limit values
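One common reading of this kind of fix (an assumption on my part, since the message does not spell it out) is to mask attention scores with the dtype's own minimum instead of a hard-coded large negative constant, so the fill value stays representable in fp16/bf16:

```python
import torch

def masked_softmax_sketch(scores: torch.Tensor, padding_mask: torch.Tensor) -> torch.Tensor:
    # `padding_mask` is True where attention should be blocked.
    # torch.finfo(dtype).min is the most negative representable value for the
    # scores' dtype, avoiding the overflow a fixed -1e9 can cause in fp16.
    fill_value = torch.finfo(scores.dtype).min
    scores = scores.masked_fill(padding_mask, fill_value)
    return torch.nn.functional.softmax(scores, dim=-1)
```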

* optimize building alibi tensor
Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com>

* fix attention_mask shape when it's None

* minor fixes

- fix docstring + arg names

* remove colons in docstring

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* apply suggestion

* remove unused arg

* refactor a bit

- use [:, None] for consistency

* refactor attention block
Co-authored-by: Nouamane Tazi <nouamane98@gmail.com>

* quick fixes

* first attempt

* refactor attention block and fix all tests except `test_simple_generation`

- added comments to better explain attention block

* remove debug lines and add TODO comment

* change `torch.bmm` to `torch.baddbmm`
- fixes `test_simple_generation` but breaks `test_batch_generation_padd`
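For context on the swap, the two calls below compute the same attention scores; `torch.baddbmm` folds the additive ALiBi bias and the 1/sqrt(head_dim) scaling into a single fused call (shapes and variable names here are illustrative):

```python
import torch

# (batch*heads, q_len, head_dim) x (batch*heads, head_dim, k_len) -> (batch*heads, q_len, k_len)
batch_heads, q_len, k_len, head_dim = 4, 8, 8, 16
query = torch.randn(batch_heads, q_len, head_dim)
key_t = torch.randn(batch_heads, head_dim, k_len)
alibi = torch.randn(batch_heads, 1, k_len)  # broadcast over the query dimension
inv_norm_factor = 1.0 / (head_dim ** 0.5)

# Two-step version: batched matmul, then scale and add the bias.
scores_bmm = torch.bmm(query, key_t) * inv_norm_factor + alibi

# Fused version: beta * alibi + alpha * (query @ key_t) in one call.
scores_baddbmm = torch.baddbmm(alibi, query, key_t, beta=1.0, alpha=inv_norm_factor)

assert torch.allclose(scores_bmm, scores_baddbmm, atol=1e-5)
```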

* styling

* all tests are passing now
- use `bmm`
- add explanation for `allow_fp16_reduced_precision_reduction`
Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com>
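The flag referenced in the last bullet is a standard PyTorch toggle; a short sketch of how it can be inspected or pinned when chasing this kind of flakiness:

```python
import torch

# When True, fp16 GEMMs may accumulate in reduced precision, which changes
# `torch.bmm`/`torch.baddbmm` results enough to flip greedy-decoding outputs
# for small models.
print(torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction)

# Force full-precision accumulation for reproducible comparisons.
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
```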

* styling
Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com>

* fix support for accelerate
Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com>

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* remove attn softmax in fp32

* refactor comments

* refactor a bit

- remove warning message
- remove print on test

* refer to pytorch t5

* change the slow tests

- do the tests in fp32
- remove some comments
- keep large comments

* update expected output for `test_simple_generation`
- we now test using fp32
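For context, the dtype behind the new expected output is controlled by the `torch_dtype` argument; a hedged sketch (checkpoint name taken from the tests):

```python
from transformers import BloomForCausalLM

# `torch_dtype="auto"` loads the weights in the dtype stored in the checkpoint
# (half precision here), which is where the flaky generations came from:
# model_fp16 = BloomForCausalLM.from_pretrained("bigscience/bloom-350m", torch_dtype="auto")

# Leaving torch_dtype unset keeps the framework default (fp32), matching the
# updated EXPECTED_OUTPUT in `test_simple_generation`.
model_fp32 = BloomForCausalLM.from_pretrained("bigscience/bloom-350m")
```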

* make style + change comments a bit

* fix dtype padd test
Co-authored-by: justheuristic <justheuristic@gmail.com>
Co-authored-by: Nouamane Tazi <nouamane98@gmail.com>
Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 5ff6f853
@@ -72,9 +72,6 @@ class BloomConfig(PretrainedConfig):
             If set to `True`, it will skip bias add for each linear layer in the transformer blocks
         skip_bias_add_qkv (`bool`, *optional*, defaults to `False`):
             If set to `True`, it will skip bias add for the first linear layer in the transformer blocks
-        attention_softmax_in_fp32 (`bool`, *optional*, defaults to `True`):
-            If set to `True` and the `dtype` is set to `float16` it will scale the input of the Softmax function to
-            `fp32`
         hidden_dropout (`float`, *optional*, defaults to 0.1):
             Dropout rate of the dropout function on the bias dropout.
         attention_dropout (`float`, *optional*, defaults to 0.1):
@@ -128,7 +125,6 @@ class BloomConfig(PretrainedConfig):
         hidden_size=64,
         n_layer=2,
         n_head=8,
-        masked_softmax_fusion=True,
         layer_norm_epsilon=1e-5,
         initializer_range=0.02,
         use_cache=False,
@@ -137,7 +133,6 @@ class BloomConfig(PretrainedConfig):
         apply_residual_connection_post_layernorm=False,
         hidden_dropout=0.0,
         attention_dropout=0.0,
-        attention_softmax_in_fp32=True,
         pretraining_tp=1,  # TP rank used when training with megatron
         dtype="bfloat16",
         slow_but_exact=False,
@@ -147,7 +142,6 @@ class BloomConfig(PretrainedConfig):
         self.hidden_size = hidden_size
         self.n_layer = n_layer
         self.n_head = n_head
-        self.masked_softmax_fusion = masked_softmax_fusion
         self.layer_norm_epsilon = layer_norm_epsilon
         self.initializer_range = initializer_range
         self.use_cache = use_cache
@@ -155,7 +149,6 @@ class BloomConfig(PretrainedConfig):
         self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
         self.hidden_dropout = hidden_dropout
         self.attention_dropout = attention_dropout
-        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
         self.bos_token_id = bos_token_id
         self.eos_token_id = eos_token_id
...
@@ -377,15 +377,34 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
     @slow
     @require_torch_gpu
     def test_simple_generation(self):
+        # This test is a bit flaky. For some GPU architectures, pytorch sets allow_fp16_reduced_precision_reduction = True by default, and some operations
+        # do not give the same results under this configuration, especially torch.baddbmm and torch.bmm. https://pytorch.org/docs/stable/notes/numerical_accuracy.html#fp16-on-mi200
+        # We leave allow_fp16_reduced_precision_reduction at its default value (True). Please see: https://pytorch.org/docs/stable/notes/cuda.html#reduced-precision-reduction-in-fp16-gemms
+        # This discrepancy is observed only when using small models; the outputs seem to be stable for larger models.
+        # Our conclusion is that these operations are flaky for small inputs but seem to be stable for larger inputs (for `baddbmm` and `bmm`), and therefore for larger models.
+        # Here is a summary of an ablation study of our observations:
+        # EXPECTED_OUTPUT = "I enjoy walking with my cute dog, and I love to watch the kids play. I am a very active person, and I am a very good listener. I am a very good person, and I am a very good person. I am a"
+        # 350m + allow_fp16_reduced_precision_reduction = False + torch.bmm ==> PASS
+        # 350m + allow_fp16_reduced_precision_reduction = False + torch.baddbmm ==> PASS
+        # 350m + allow_fp16_reduced_precision_reduction = True + torch.baddbmm ==> PASS
+        # 350m + allow_fp16_reduced_precision_reduction = True + torch.bmm ==> FAIL
+        # EXPECTED_OUTPUT = "I enjoy walking with my cute dog, but I also enjoy hiking, biking, and swimming. I love to cook and bake. I love to cook and bake. I love to cook and bake. I love to cook and bake. I love"
+        # >=760m + allow_fp16_reduced_precision_reduction = True + torch.baddbmm ==> PASS (for use_cache=True and use_cache=False)
+        # >=760m + allow_fp16_reduced_precision_reduction = True + torch.bmm ==> PASS
+        # >=760m + allow_fp16_reduced_precision_reduction = False + torch.bmm ==> PASS
         path_350m = "bigscience/bloom-350m"
-        model = BloomForCausalLM.from_pretrained(path_350m, torch_dtype="auto", use_cache=True).cuda()
+        model = BloomForCausalLM.from_pretrained(path_350m, use_cache=True).cuda()
         model = model.eval()
         tokenizer = BloomTokenizerFast.from_pretrained(path_350m)
         input_sentence = "I enjoy walking with my cute dog"
+        # This output has been obtained using the fp32 model on the huggingface DGX workstation - NVIDIA A100 GPU
         EXPECTED_OUTPUT = (
-            "I enjoy walking with my cute dog, and I love to watch the kids play. I am a very active person, and I am"
-            " a very good listener. I am a very good person, and I am a very good person. I am a"
+            "I enjoy walking with my cute dog, and I love to watch the kids play with the kids. I am a very "
+            "active person, and I enjoy working out, and I am a very active person. I am a very active person, and I"
         )
         input_ids = tokenizer.encode(input_sentence, return_tensors="pt")
@@ -397,7 +416,7 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
     @require_torch_gpu
     def test_batch_generation(self):
         path_350m = "bigscience/bloom-350m"
-        model = BloomForCausalLM.from_pretrained(path_350m, torch_dtype="auto", use_cache=True).cuda()
+        model = BloomForCausalLM.from_pretrained(path_350m, use_cache=True).cuda()
         model = model.eval()
         tokenizer = BloomTokenizerFast.from_pretrained(path_350m, padding_side="left")
@@ -416,8 +435,9 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
     @slow
     @require_torch_gpu
     def test_batch_generation_padd(self):
         path_350m = "bigscience/bloom-350m"
-        model = BloomForCausalLM.from_pretrained(path_350m, torch_dtype="auto", use_cache=True).cuda()
+        model = BloomForCausalLM.from_pretrained(path_350m, use_cache=True).cuda()
         model = model.eval()
         tokenizer = BloomTokenizerFast.from_pretrained(path_350m, padding_side="left")
...