"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "9beaa85b071078f84037f6a036ea042f551a8623"
Unverified Commit bc34c211 authored by Niklas Muennighoff, committed by GitHub

Fix BLOOM dtype (#17995)

* Add fp16 option

* Fix BLOOM dtype

* Formatting

* Remove torch_dtype arg

* Revert formatting

* Apply formatting

* Add n_embed backward compat
parent 981714ef
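
Taken together, the patch below drops BLOOM's bespoke `dtype` config attribute and leaves precision selection to the standard `torch_dtype` machinery shared by all models. A minimal usage sketch of that path (the checkpoint name is illustrative and not part of this commit):

```python
import torch
from transformers import BloomForCausalLM

# "auto" picks up the dtype recorded in the checkpoint's config (config.torch_dtype)
# instead of a BLOOM-specific `dtype` attribute.
model = BloomForCausalLM.from_pretrained("bigscience/bloom-560m", torch_dtype="auto")
print(model.transformer.word_embeddings.weight.dtype)

# An explicit precision can still be requested at load time.
model_fp32 = BloomForCausalLM.from_pretrained("bigscience/bloom-560m", torch_dtype=torch.float32)
```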
@@ -78,9 +78,6 @@ class BloomConfig(PretrainedConfig):
             Dropout rate applied to the attention probs
         use_cache (`bool`, *optional*, defaults to `True`):
             Whether or not the model should return the last key/values attentions (not used by all models).
-        dtype (`str`, *optional*, defaults to `"bfloat16"`):
-            Precision that has been used for the model's training in Megatron. Please load the model in the correct
-            precision by doing `model = BloomModel.from_pretrained(model_name, torch_dtype="auto")`.`
         pretraining_tp (`int`, *optional*, defaults to `1`):
             Experimental feature. Tensor parallelism rank used during pretraining with Megatron. Please refer to [this
             document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
@@ -114,9 +111,7 @@ class BloomConfig(PretrainedConfig):
     keys_to_ignore_at_inference = ["past_key_values"]
     attribute_map = {
         "num_hidden_layers": "n_layer",
-        "n_head": "num_attention_heads",
-        "hidden_size": "n_embed",
-        "dtype": "torch_dtype",
+        "num_attention_heads": "n_head",
     }

     def __init__(
@@ -134,12 +129,13 @@ class BloomConfig(PretrainedConfig):
         hidden_dropout=0.0,
         attention_dropout=0.0,
         pretraining_tp=1,  # TP rank used when training with megatron
-        dtype="bfloat16",
         slow_but_exact=False,
         **kwargs,
     ):
         self.vocab_size = vocab_size
-        self.hidden_size = hidden_size
+        # Backward compatibility with n_embed kwarg
+        n_embed = kwargs.pop("n_embed", None)
+        self.hidden_size = hidden_size if n_embed is None else n_embed
         self.n_layer = n_layer
         self.n_head = n_head
         self.layer_norm_epsilon = layer_norm_epsilon
@@ -152,7 +148,6 @@ class BloomConfig(PretrainedConfig):
         self.bos_token_id = bos_token_id
         self.eos_token_id = eos_token_id
-        self.dtype = dtype
         self.slow_but_exact = slow_but_exact

         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
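
The backward-compatibility branch added above means configs written with the older `n_embed` name keep loading: the kwarg is popped and mapped onto `hidden_size`. A small sketch of the intended behavior (values are illustrative):

```python
from transformers import BloomConfig

# Legacy code that passed `n_embed` still works after this change.
legacy = BloomConfig(n_embed=1024)
print(legacy.hidden_size)  # 1024

# New-style configs set hidden_size directly.
config = BloomConfig(hidden_size=1024)
print(config.hidden_size)  # 1024
```

The remaining hunk touches the checkpoint conversion script: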
@@ -203,7 +203,8 @@ def convert_bloom_checkpoint_to_pytorch(
         os.makedirs(pytorch_dump_folder_path, exist_ok=True)
         pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
         pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
-        print(f"Save PyTorch model to {pytorch_weights_dump_path}")
+        print(f"Save PyTorch model to {pytorch_weights_dump_path} with dtype {config.torch_dtype}")
+        model = model.to(config.torch_dtype)
         torch.save(model.state_dict(), pytorch_weights_dump_path)
         print(f"Save configuration file to {pytorch_config_dump_path}")
         with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
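
With the added cast, the conversion script writes the state dict in the precision recorded in `config.torch_dtype` rather than whatever dtype the in-memory model happened to have. A rough sketch of the effect, using a tiny made-up config instead of a real checkpoint:

```python
import torch
from transformers import BloomConfig, BloomModel

# Illustrative toy config; only torch_dtype matters here.
config = BloomConfig(hidden_size=64, n_layer=2, n_head=8, torch_dtype=torch.float16)
model = BloomModel(config)            # parameters are materialized in float32 by default

model = model.to(config.torch_dtype)  # the cast added by this commit, applied before saving
print(model.word_embeddings.weight.dtype)  # torch.float16
```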