Unverified Commit 2dbfbc74 authored by Santosh Bhavani, committed by GitHub

fix(examples): te_llama compatibility with transformers >= 4.57 (#2572)



* fix(examples): te_llama compatibility with HuggingFace transformers >= 4.57

The te_llama.py example was failing with HuggingFace transformers 4.57+
due to API changes in how decoder layer outputs are handled.

Changes:
- Handle case where hidden_states is passed as a tuple (older HF versions)
- Return tensor directly instead of wrapped in tuple (HF 4.57+ expects this)
- Fix regex pattern to use raw string (fixes SyntaxWarning)

Error fixed:
  AttributeError: 'tuple' object has no attribute 'contiguous'
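
For context, this is roughly the calling-convention change behind the error: newer transformers releases use the decoder layer's return value as a tensor directly (for example, calling .contiguous() on it), whereas older releases unpacked a tuple via layer_outputs[0]. A minimal, self-contained sketch of the failure mode (the dummy tensor and shapes below are illustrative, not taken from the example):

    import torch

    hidden_states = torch.randn(1, 4, 8)

    # Old-style layer output: a 1-tuple wrapping the tensor.
    old_style_output = (hidden_states,)

    # transformers >= 4.57 uses the return value directly, which fails on a tuple:
    try:
        old_style_output.contiguous()
    except AttributeError as err:
        print(err)  # 'tuple' object has no attribute 'contiguous'

    # Returning the tensor itself (as the patched layer now does) works:
    print(old_style_output[0].contiguous().shape)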

Tested with:
- transformer_engine 2.5.0
- transformers 4.57.3
- PyTorch container nvcr.io/nvidia/pytorch:25.08-py3
Signed-off-by: Santosh Bhavani <santosh.bhavani@live.com>

* docs(te_llama): add requirements.txt
Signed-off-by: Santosh Bhavani <santosh.bhavani@live.com>

* fix(docs): add missing notebook output names
Signed-off-by: Santosh Bhavani <santosh.bhavani@live.com>

---------
Signed-off-by: Santosh Bhavani <santosh.bhavani@live.com>
parent 72592763
requirements.txt:
transformers==4.57.0
accelerate==1.10.0
peft==0.15.2
datasets==4.0.0
sentencepiece==0.2.1
@@ -72,10 +72,15 @@ class TELlamaDecoderLayer(te.pytorch.TransformerLayer):
         forward pass of the `TransformerLayer`. Also, make sure the output
         format matches the output of the HF's `LlamaDecoderLayer`.
         """
-        return (
-            super().forward(
-                hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb
-            ),
-        )
+        # Handle case where hidden_states might be a tuple (from previous layer output)
+        # This can happen with older versions of HuggingFace transformers
+        if isinstance(hidden_states, tuple):
+            hidden_states = hidden_states[0]
+
+        # Return tensor directly for HuggingFace transformers >= 4.57
+        # (older versions wrapped output in tuple and extracted with layer_outputs[0])
+        return super().forward(
+            hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb
+        )
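
The tuple-handling guard in the patch can be exercised on its own; the following is a minimal sketch (the unwrap_hidden_states helper name and the tensor shapes are made up for illustration) showing that both the old tuple-style and the new tensor-style inputs reduce to the same tensor:

    import torch

    def unwrap_hidden_states(hidden_states):
        # Older HuggingFace transformers may hand the previous layer's output
        # through as a tuple; newer releases pass the tensor directly.
        if isinstance(hidden_states, tuple):
            hidden_states = hidden_states[0]
        return hidden_states

    x = torch.randn(2, 8, 16)  # [batch, seq, hidden] dummy activations
    assert unwrap_hidden_states(x) is x      # tensor input: passed through
    assert unwrap_hidden_states((x,)) is x   # tuple input: unwrapped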
@@ -162,7 +167,7 @@ def replace_params(hf_state_dict, te_state_dict, config):
     # collect all layer prefixes to update
     all_layer_prefixes = set()
     for param_key in hf_state_dict.keys():
-        layer_prefix_pat = "model.layers.\d+."
+        layer_prefix_pat = r"model.layers.\d+."
         m = re.match(layer_prefix_pat, param_key)
         if m is not None:
             all_layer_prefixes.add(m.group())
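
The raw-string prefix only silences the invalid \d escape in a normal string literal (a SyntaxWarning on recent Python versions); the matching behavior is unchanged. A small standalone check, using a made-up parameter key in the usual HF naming style:

    import re

    layer_prefix_pat = r"model.layers.\d+."
    param_key = "model.layers.0.self_attn.q_proj.weight"  # illustrative key

    m = re.match(layer_prefix_pat, param_key)
    print(m.group())  # model.layers.0.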