"vscode:/vscode.git/clone" did not exist on "9c29e938d7bf5e8b40d9a5f861a6b7a6a7c7c7d7"
Unverified commit 0a0dd34e, authored by fzyzcjy, committed by GitHub
Browse files

Fix BumpAllocator error when no input_ids (#5564)

parent 80ac527d
...@@ -94,7 +94,9 @@ class DeepseekModelNextN(nn.Module): ...@@ -94,7 +94,9 @@ class DeepseekModelNextN(nn.Module):
zero_allocator = BumpAllocator( zero_allocator = BumpAllocator(
buffer_size=2, buffer_size=2,
dtype=torch.float32, dtype=torch.float32,
device=input_ids.device, device=(
input_embeds.device if input_embeds is not None else input_ids.device
),
) )
if input_embeds is None: if input_embeds is None:
......
...@@ -1374,7 +1374,9 @@ class DeepseekV2Model(nn.Module): ...@@ -1374,7 +1374,9 @@ class DeepseekV2Model(nn.Module):
# TODO for two-batch-overlap, we need a larger buffer size # TODO for two-batch-overlap, we need a larger buffer size
buffer_size=len(self.layers) * 2, buffer_size=len(self.layers) * 2,
dtype=torch.float32, dtype=torch.float32,
device=input_ids.device, device=(
input_embeds.device if input_embeds is not None else input_ids.device
),
) )
if input_embeds is None: if input_embeds is None:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment