Unverified Commit 8d259fad authored by Anna Shors's avatar Anna Shors Committed by GitHub
Browse files

Fix gpt oss weight loading with EP + bf16 (#28765)


Signed-off-by: default avatarashors1 <ashors@nvidia.com>
parent 3bc11757
...@@ -494,8 +494,8 @@ class GptOssModel(nn.Module): ...@@ -494,8 +494,8 @@ class GptOssModel(nn.Module):
def _load_weights_other( def _load_weights_other(
self, self,
ep_rank_start: int,
ep_rank_end: int, ep_rank_end: int,
ep_rank_start: int,
heads_per_rank: int, heads_per_rank: int,
head_start: int, head_start: int,
weights: Iterable[tuple[str, torch.Tensor]], weights: Iterable[tuple[str, torch.Tensor]],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment