"googlemock/vscode:/vscode.git/clone" did not exist on "1df907381de5483def338eff280697e0dafb2b75"
Unverified Commit 16a469df authored by Paweł Gadziński's avatar Paweł Gadziński Committed by GitHub
Browse files

Llama tutorial fixes (#730)



Llama tutorial fixes - all
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>
parent 12cbd863
......@@ -56,7 +56,7 @@ class TELlamaDecoderLayer(te.pytorch.TransformerLayer):
normalization="RMSNorm",
activation="swiglu",
attn_input_format="bshd",
num_gqa_groups=config.num_key_value_heads,
num_gqa_groups=config.num_key_value_heads
)
te_rope = RotaryPositionEmbedding(config.hidden_size//config.num_attention_heads)
self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda()
......@@ -121,12 +121,12 @@ class TELlamaForCausalLM:
assert not isinstance(resolved_archive_file, list)
resolved_archive_file = [resolved_archive_file]
error_msgs = []
for shard_file in resolved_archive_file:
state_dict = load_state_dict(shard_file)
replaced_layers = replace_params(state_dict, vanilla_model.state_dict())
error_msgs += _load_state_dict_into_model(vanilla_model, state_dict, start_prefix="")
# replace_params copies parameters relevant only to TransformerEngine
replace_params(state_dict, vanilla_model.state_dict(), config)
# _load_state_dict_into_model copies parameters other than those in TransformerEngine
_load_state_dict_into_model(vanilla_model, state_dict, start_prefix="")
# Force mem release. Taken from huggingface code
del state_dict
......@@ -134,7 +134,7 @@ class TELlamaForCausalLM:
return vanilla_model
def replace_params(hf_state_dict, te_state_dict):
def replace_params(hf_state_dict, te_state_dict, config):
# collect all layer prefixes to update
all_layer_prefixes = set()
for param_key in hf_state_dict.keys():
......@@ -142,32 +142,40 @@ def replace_params(hf_state_dict, te_state_dict):
m = re.match(layer_prefix_pat, param_key)
if m is not None:
all_layer_prefixes.add(m.group())
for layer_prefix in all_layer_prefixes:
# When loading weights into models with fewer layers, skip the
# copy if the corresponding layer doesn't exist in TE model
if layer_prefix + 'self_attention.layernorm_qkv.layer_norm_weight' in te_state_dict:
# copy if the corresponding layer doesn't exist in HF model
if layer_prefix + 'input_layernorm.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.layer_norm_weight'].data[:] = hf_state_dict[layer_prefix + 'input_layernorm.weight'].data[:]
if layer_prefix + 'self_attention.layernorm_qkv.query_weight' in te_state_dict:
if layer_prefix + 'self_attn.q_proj.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.query_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.q_proj.weight'].data[:]
if layer_prefix + 'self_attention.layernorm_qkv.key_weight' in te_state_dict:
if layer_prefix + 'self_attn.k_proj.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.key_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.k_proj.weight'].data[:]
if layer_prefix + 'self_attention.layernorm_qkv.value_weight' in te_state_dict:
if layer_prefix + 'self_attn.v_proj.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.value_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.v_proj.weight'].data[:]
if layer_prefix + 'self_attention.proj.weight' in te_state_dict:
if layer_prefix + 'self_attn.o_proj.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'self_attention.proj.weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.o_proj.weight'].data[:]
if layer_prefix + 'layernorm_mlp.layer_norm_weight' in te_state_dict:
if layer_prefix + 'post_attention_layernorm.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'layernorm_mlp.layer_norm_weight'].data[:] = hf_state_dict[layer_prefix + 'post_attention_layernorm.weight'].data[:]
if layer_prefix + 'layernorm_mlp.fc1_weight' in te_state_dict:
te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[:] = torch.cat((hf_state_dict[layer_prefix + 'mlp.gate_proj.weight'].data[:], hf_state_dict[layer_prefix + 'mlp.up_proj.weight'].data[:]), dim=0)
if layer_prefix + 'layernorm_mlp.fc2_weight' in te_state_dict:
# It may happen that gate_proj.weight and up_proj.weight are in different files, so we need to
# load them separately.
if layer_prefix + 'mlp.gate_proj.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[:config.intermediate_size] = \
hf_state_dict[layer_prefix + 'mlp.gate_proj.weight'].data
if layer_prefix + 'mlp.up_proj.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[config.intermediate_size:] = \
hf_state_dict[layer_prefix + 'mlp.up_proj.weight'].data
if layer_prefix + 'mlp.down_proj.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'layernorm_mlp.fc2_weight'].data[:] = hf_state_dict[layer_prefix + 'mlp.down_proj.weight'].data[:]
return all_layer_prefixes
\ No newline at end of file
......@@ -231,7 +231,8 @@
"\n",
"# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n",
"## !!! `model_name` attr must point to the location of the model weights !!!\n",
"## Weights can be downloaded from: https://llama.meta.com/llama-downloads/\n",
"## Weights can be downloaded from: https://llama.meta.com/llama-downloads/ and then converted to the HuggingFace format.\n",
"## Instructions for conversion are available on the website https://ai.meta.com/blog/5-steps-to-getting-started-with-llama-2/ - steps 1 and 2.\n",
"hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n",
"hyperparams.mixed_precision = \"bf16\"\n",
"\n",
......@@ -556,7 +557,8 @@
"\n",
"# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n",
"## !!! `model_name` attr must point to the location of the model weights !!!\n",
"## Weights can be downloaded from: https://llama.meta.com/llama-downloads/\n",
"## Weights can be downloaded from: https://llama.meta.com/llama-downloads/ and then converted to the HuggingFace format.\n",
"## Instructions for conversion are available on the website https://ai.meta.com/blog/5-steps-to-getting-started-with-llama-2/ - steps 1 and 2.\n",
"hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n",
"hyperparams.mixed_precision = \"bf16\"\n",
"\n",
......@@ -635,7 +637,8 @@
"\n",
"# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n",
"## !!! `model_name` attr must point to the location of the model weights !!!\n",
"## Weights can be downloaded from: https://llama.meta.com/llama-downloads/\n",
"## Weights can be downloaded from: https://llama.meta.com/llama-downloads/ and then converted to the HuggingFace format.\n",
"## Instructions for conversion are available on the website https://ai.meta.com/blog/5-steps-to-getting-started-with-llama-2/ - steps 1 and 2.\n",
"hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n",
"hyperparams.mixed_precision = \"fp8\"\n",
"\n",
......
......@@ -91,6 +91,7 @@ def init_te_llama_model(hyperparams):
# Init the model
from te_llama import TELlamaForCausalLM
config = AutoConfig.from_pretrained(hyperparams.model_name)
config._attn_implementation = "flash_attention_2"
model = TELlamaForCausalLM.from_pretrained_local(
hyperparams.model_name,
config=config,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment