Unverified Commit 97c9a41f authored by Yifan Xiong's avatar Yifan Xiong Committed by GitHub
Browse files

Benchmark - Update TE FP8 model conversion (#499)

__Description__

Update TE FP8 model conversion.

__Major Revisions__
* Add 16-byte alignment comment.
* Fix TE layer parameters type.
parent c88c9709
...@@ -63,15 +63,16 @@ def _to_te_model(self, model): ...@@ -63,15 +63,16 @@ def _to_te_model(self, model):
return return
for name, m in model.named_children(): for name, m in model.named_children():
if isinstance(m, torch.nn.Linear): if isinstance(m, torch.nn.Linear):
# check 16-byte alignment
if any(p % 16 != 0 for p in m.weight.shape): if any(p % 16 != 0 for p in m.weight.shape):
return return
te_m = te.Linear(m.in_features, m.out_features, bias=(m.bias is not None)) te_m = te.Linear(m.in_features, m.out_features, bias=(m.bias is not None), params_dtype=m.weight.dtype)
te_m.weight.copy_(m.weight) te_m.weight.copy_(m.weight)
if m.bias is not None: if m.bias is not None:
te_m.bias.copy_(m.bias) te_m.bias.copy_(m.bias)
setattr(model, name, te_m) setattr(model, name, te_m)
elif isinstance(m, torch.nn.LayerNorm): elif isinstance(m, torch.nn.LayerNorm):
te_m = te.LayerNorm(m.normalized_shape[0], eps=m.eps) te_m = te.LayerNorm(m.normalized_shape[0], eps=m.eps, params_dtype=m.weight.dtype)
if hasattr(te_m, 'weight'): if hasattr(te_m, 'weight'):
te_m.weight.copy_(m.weight) te_m.weight.copy_(m.weight)
te_m.bias.copy_(m.bias) te_m.bias.copy_(m.bias)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment