Unverified Commit b69a62d5 authored by Thomas Wang, committed by GitHub

[BLOOM] Clean modeling code (#18344)



* Cleanup some code

* Improve signatures

* Try to reduce the number of reshape/copies

* I don't think we actually need the layer_num scaling trick

* No need for duplication

* Try to fix beam_search

* Fix beam search

* Removing layer_num normalization seems to break things

* Not sure self.layer_number normalization actually matters

* Try and be backward compatible

* Try to fix beam_search

* Revert attempt to be backward compatible

* Improve documentation on past_key_values format (see the layout sketch below)

* Optimize the device allocation in case of hidden_states in multiple devices

* No need to manually cast the values to a specific device

* Rename with long version of variables

* Improve type hinting

* Add comment that explains that some methods return views

* Actually, I think the attention casting only makes sense when we use torch.float16

* We don't actually need layer_number to be passed anymore

* Fix FX test

* Bypass torch.baddbmm

* Apply suggestions from code review

* Add comment about support for TorchScript v1.11

* Fix ONNX support for BLOOM (#18456)
Co-authored-by: Niklas Muennighoff <n.muennighoff@gmail.com>
Co-authored-by: Nouamane Tazi <nouamane98@gmail.com>
parent 02b176c4
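
For readers skimming the commits above: several of them (the past_key_values documentation, the baddbmm change, the float16 casting) revolve around BLOOM's fused cache layout, where each layer's cached key is stored as [batch_size * num_heads, head_dim, seq_length] and each value as [batch_size * num_heads, seq_length, head_dim]. The sketch below illustrates why that layout is convenient: the key is already transposed for the attention-score matmul, and the ALiBi bias can be folded in with torch.baddbmm. All sizes and tensor names here are invented for illustration; this is not the PR's modeling code.

```python
import torch

# Hypothetical sizes, for illustration only.
batch_size, num_heads, head_dim, q_length, kv_length = 2, 4, 8, 1, 10
inv_norm_factor = 1.0 / head_dim**0.5  # 1 / sqrt(head_dim)

# Fused cache layout: key stored already transposed, value in "natural" order.
past_key = torch.randn(batch_size * num_heads, head_dim, kv_length)
past_value = torch.randn(batch_size * num_heads, kv_length, head_dim)

query = torch.randn(batch_size * num_heads, q_length, head_dim)
alibi = torch.randn(batch_size * num_heads, 1, kv_length)  # ALiBi bias per head

# baddbmm fuses the bias add with the batched matmul:
#   scores = 1.0 * alibi + inv_norm_factor * (query @ past_key)
scores = torch.baddbmm(alibi, query, past_key, beta=1.0, alpha=inv_norm_factor)

probs = torch.softmax(scores, dim=-1)    # [batch*heads, q_length, kv_length]
context = torch.bmm(probs, past_value)   # [batch*heads, q_length, head_dim]
print(scores.shape, context.shape)
```
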
```diff
@@ -214,14 +214,19 @@ class BloomOnnxConfig(OnnxConfigWithPast):
             batch, seqlen = common_inputs["input_ids"].shape
             # Not using the same length for past_key_values
             past_key_values_length = seqlen + 2
-            past_shape = (
-                batch,
-                past_key_values_length,
-                self.num_attention_heads,
-                self._config.hidden_size // self.num_attention_heads,
+            head_dim = self._config.hidden_size // self.num_attention_heads
+            past_key_shape = (
+                batch * self.num_attention_heads,
+                head_dim,
+                past_key_values_length,
+            )
+            past_value_shape = (
+                batch * self.num_attention_heads,
+                past_key_values_length,
+                head_dim,
             )
             ordered_inputs["past_key_values"] = [
-                (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers)
+                (torch.zeros(past_key_shape), torch.zeros(past_value_shape)) for _ in range(self.num_layers)
             ]
         ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
```
...
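
As a quick sanity check of the hunk above, here is a standalone, hypothetical rendering of the new dummy-input logic with made-up values (the real batch, seqlen, num_attention_heads, hidden_size, and num_layers come from the ONNX config and self._config):

```python
import torch

# Made-up values standing in for the ONNX config and model config.
batch, seqlen = 2, 3
num_attention_heads, hidden_size, num_layers = 4, 64, 2

# Not using the same length for past_key_values
past_key_values_length = seqlen + 2
head_dim = hidden_size // num_attention_heads

past_key_shape = (batch * num_attention_heads, head_dim, past_key_values_length)
past_value_shape = (batch * num_attention_heads, past_key_values_length, head_dim)

# One zero-filled (key, value) pair per layer, in BLOOM's fused layout.
past_key_values = [
    (torch.zeros(past_key_shape), torch.zeros(past_value_shape)) for _ in range(num_layers)
]
print(past_key_values[0][0].shape, past_key_values[0][1].shape)
# torch.Size([8, 16, 5]) torch.Size([8, 5, 16])
```
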