Unverified Commit c988548f authored by xiaoxi-wangfj's avatar xiaoxi-wangfj Committed by GitHub
Browse files

[PyTorch] Fix garbage initialized permuted_scale (#2547)


Signed-off-by: default avatarxiaoxi-wangfj <690912414@qq.com>
Co-authored-by: default avatarTeddy Do <tdophung@nvidia.com>
parent 27dc83bf
...@@ -165,7 +165,7 @@ def permute_with_mask_map( ...@@ -165,7 +165,7 @@ def permute_with_mask_map(
alloc((num_out_tokens,), dtype=probs.dtype, device="cuda") if probs is not None else None alloc((num_out_tokens,), dtype=probs.dtype, device="cuda") if probs is not None else None
) )
permuted_scale = ( permuted_scale = (
torch.empty((num_out_tokens, scale_hidden_dim), dtype=scale.dtype, device="cuda") alloc((num_out_tokens, scale_hidden_dim), dtype=scale.dtype, device="cuda")
if scale is not None if scale is not None
else None else None
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment