Unverified commit 0a0dd34e, authored by fzyzcjy and committed by GitHub
Browse files

Fix BumpAllocator error when no input_ids (#5564)

parent 80ac527d
......@@ -94,7 +94,9 @@ class DeepseekModelNextN(nn.Module):
zero_allocator = BumpAllocator(
buffer_size=2,
dtype=torch.float32,
-device=input_ids.device,
+device=(
+input_embeds.device if input_embeds is not None else input_ids.device
+),
)
if input_embeds is None:
......
......@@ -1374,7 +1374,9 @@ class DeepseekV2Model(nn.Module):
# TODO for two-batch-overlap, we need a larger buffer size
buffer_size=len(self.layers) * 2,
dtype=torch.float32,
-device=input_ids.device,
+device=(
+input_embeds.device if input_embeds is not None else input_ids.device
+),
)
if input_embeds is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment