Unverified Commit 16a469df authored by Paweł Gadziński's avatar Paweł Gadziński Committed by GitHub
Browse files

Llama tutorial fixes (#730)



Llama tutorial fixes - all
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>
parent 12cbd863
...@@ -56,7 +56,7 @@ class TELlamaDecoderLayer(te.pytorch.TransformerLayer): ...@@ -56,7 +56,7 @@ class TELlamaDecoderLayer(te.pytorch.TransformerLayer):
normalization="RMSNorm", normalization="RMSNorm",
activation="swiglu", activation="swiglu",
attn_input_format="bshd", attn_input_format="bshd",
num_gqa_groups=config.num_key_value_heads, num_gqa_groups=config.num_key_value_heads
) )
te_rope = RotaryPositionEmbedding(config.hidden_size//config.num_attention_heads) te_rope = RotaryPositionEmbedding(config.hidden_size//config.num_attention_heads)
self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda() self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda()
...@@ -121,12 +121,12 @@ class TELlamaForCausalLM: ...@@ -121,12 +121,12 @@ class TELlamaForCausalLM:
assert not isinstance(resolved_archive_file, list) assert not isinstance(resolved_archive_file, list)
resolved_archive_file = [resolved_archive_file] resolved_archive_file = [resolved_archive_file]
error_msgs = []
for shard_file in resolved_archive_file: for shard_file in resolved_archive_file:
state_dict = load_state_dict(shard_file) state_dict = load_state_dict(shard_file)
replaced_layers = replace_params(state_dict, vanilla_model.state_dict()) # replace_params copies parameters relevant only to TransformerEngine
replace_params(state_dict, vanilla_model.state_dict(), config)
error_msgs += _load_state_dict_into_model(vanilla_model, state_dict, start_prefix="") # _load_state_dict_into_model copies parameters other than those in TransformerEngine
_load_state_dict_into_model(vanilla_model, state_dict, start_prefix="")
# Force mem release. Taken from huggingface code # Force mem release. Taken from huggingface code
del state_dict del state_dict
...@@ -134,7 +134,7 @@ class TELlamaForCausalLM: ...@@ -134,7 +134,7 @@ class TELlamaForCausalLM:
return vanilla_model return vanilla_model
def replace_params(hf_state_dict, te_state_dict): def replace_params(hf_state_dict, te_state_dict, config):
# collect all layer prefixes to update # collect all layer prefixes to update
all_layer_prefixes = set() all_layer_prefixes = set()
for param_key in hf_state_dict.keys(): for param_key in hf_state_dict.keys():
...@@ -142,32 +142,40 @@ def replace_params(hf_state_dict, te_state_dict): ...@@ -142,32 +142,40 @@ def replace_params(hf_state_dict, te_state_dict):
m = re.match(layer_prefix_pat, param_key) m = re.match(layer_prefix_pat, param_key)
if m is not None: if m is not None:
all_layer_prefixes.add(m.group()) all_layer_prefixes.add(m.group())
for layer_prefix in all_layer_prefixes: for layer_prefix in all_layer_prefixes:
# When loading weights into models with less number of layers, skip the # When loading weights into models with less number of layers, skip the
# copy if the corresponding layer doesn't exist in TE model # copy if the corresponding layer doesn't exist in HF model
if layer_prefix + 'self_attention.layernorm_qkv.layer_norm_weight' in te_state_dict: if layer_prefix + 'input_layernorm.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.layer_norm_weight'].data[:] = hf_state_dict[layer_prefix + 'input_layernorm.weight'].data[:] te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.layer_norm_weight'].data[:] = hf_state_dict[layer_prefix + 'input_layernorm.weight'].data[:]
if layer_prefix + 'self_attention.layernorm_qkv.query_weight' in te_state_dict: if layer_prefix + 'self_attn.q_proj.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.query_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.q_proj.weight'].data[:] te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.query_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.q_proj.weight'].data[:]
if layer_prefix + 'self_attention.layernorm_qkv.key_weight' in te_state_dict: if layer_prefix + 'self_attn.k_proj.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.key_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.k_proj.weight'].data[:] te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.key_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.k_proj.weight'].data[:]
if layer_prefix + 'self_attention.layernorm_qkv.value_weight' in te_state_dict: if layer_prefix + 'self_attn.v_proj.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.value_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.v_proj.weight'].data[:] te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.value_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.v_proj.weight'].data[:]
if layer_prefix + 'self_attention.proj.weight' in te_state_dict: if layer_prefix + 'self_attn.o_proj.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'self_attention.proj.weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.o_proj.weight'].data[:] te_state_dict[layer_prefix + 'self_attention.proj.weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.o_proj.weight'].data[:]
if layer_prefix + 'layernorm_mlp.layer_norm_weight' in te_state_dict: if layer_prefix + 'post_attention_layernorm.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'layernorm_mlp.layer_norm_weight'].data[:] = hf_state_dict[layer_prefix + 'post_attention_layernorm.weight'].data[:] te_state_dict[layer_prefix + 'layernorm_mlp.layer_norm_weight'].data[:] = hf_state_dict[layer_prefix + 'post_attention_layernorm.weight'].data[:]
if layer_prefix + 'layernorm_mlp.fc1_weight' in te_state_dict: # It may happen that gate_proj.weight and up_proj.weight will be in the different files, so we need to
te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[:] = torch.cat((hf_state_dict[layer_prefix + 'mlp.gate_proj.weight'].data[:], hf_state_dict[layer_prefix + 'mlp.up_proj.weight'].data[:]), dim=0) # load them separately.
if layer_prefix + 'mlp.gate_proj.weight' in hf_state_dict:
if layer_prefix + 'layernorm_mlp.fc2_weight' in te_state_dict: te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[:config.intermediate_size] = \
hf_state_dict[layer_prefix + 'mlp.gate_proj.weight'].data
if layer_prefix + 'mlp.up_proj.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[config.intermediate_size:] = \
hf_state_dict[layer_prefix + 'mlp.up_proj.weight'].data
if layer_prefix + 'mlp.down_proj.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'layernorm_mlp.fc2_weight'].data[:] = hf_state_dict[layer_prefix + 'mlp.down_proj.weight'].data[:] te_state_dict[layer_prefix + 'layernorm_mlp.fc2_weight'].data[:] = hf_state_dict[layer_prefix + 'mlp.down_proj.weight'].data[:]
return all_layer_prefixes return all_layer_prefixes
\ No newline at end of file
...@@ -231,7 +231,8 @@ ...@@ -231,7 +231,8 @@
"\n", "\n",
"# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n",
"## !!! `model_name` attr must point to the location of the model weights !!!\n", "## !!! `model_name` attr must point to the location of the model weights !!!\n",
"## Weights can be downloaded from: https://llama.meta.com/llama-downloads/\n", "## Weights can be downloaded from: https://llama.meta.com/llama-downloads/ and then converted to the HuggingFace format.\n",
"## Instructions for conversion are available on the website https://ai.meta.com/blog/5-steps-to-getting-started-with-llama-2/ - steps 1 and 2.\n",
"hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n",
"hyperparams.mixed_precision = \"bf16\"\n", "hyperparams.mixed_precision = \"bf16\"\n",
"\n", "\n",
...@@ -556,7 +557,8 @@ ...@@ -556,7 +557,8 @@
"\n", "\n",
"# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n",
"## !!! `model_name` attr must point to the location of the model weights !!!\n", "## !!! `model_name` attr must point to the location of the model weights !!!\n",
"## Weights can be downloaded from: https://llama.meta.com/llama-downloads/\n", "## Weights can be downloaded from: https://llama.meta.com/llama-downloads/ and then converted to the HuggingFace format.\n",
"## Instructions for conversion are available on the website https://ai.meta.com/blog/5-steps-to-getting-started-with-llama-2/ - steps 1 and 2.\n",
"hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n",
"hyperparams.mixed_precision = \"bf16\"\n", "hyperparams.mixed_precision = \"bf16\"\n",
"\n", "\n",
...@@ -635,7 +637,8 @@ ...@@ -635,7 +637,8 @@
"\n", "\n",
"# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n",
"## !!! `model_name` attr must point to the location of the model weights !!!\n", "## !!! `model_name` attr must point to the location of the model weights !!!\n",
"## Weights can be downloaded from: https://llama.meta.com/llama-downloads/\n", "## Weights can be downloaded from: https://llama.meta.com/llama-downloads/ and then converted to the HuggingFace format.\n",
"## Instructions for conversion are available on the website https://ai.meta.com/blog/5-steps-to-getting-started-with-llama-2/ - steps 1 and 2.\n",
"hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n",
"hyperparams.mixed_precision = \"fp8\"\n", "hyperparams.mixed_precision = \"fp8\"\n",
"\n", "\n",
......
...@@ -91,6 +91,7 @@ def init_te_llama_model(hyperparams): ...@@ -91,6 +91,7 @@ def init_te_llama_model(hyperparams):
# Init the model # Init the model
from te_llama import TELlamaForCausalLM from te_llama import TELlamaForCausalLM
config = AutoConfig.from_pretrained(hyperparams.model_name) config = AutoConfig.from_pretrained(hyperparams.model_name)
config._attn_implementation = "flash_attention_2"
model = TELlamaForCausalLM.from_pretrained_local( model = TELlamaForCausalLM.from_pretrained_local(
hyperparams.model_name, hyperparams.model_name,
config=config, config=config,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment