Unverified Commit 993a187c authored by fxmarty's avatar fxmarty Committed by GitHub
Browse files

fix device in longformer onnx path (#20419)

parent bc00c29d
......@@ -425,7 +425,7 @@ class LEDEncoderSelfAttention(nn.Module):
hidden_states.size(2),
]
overlapping_chunks = torch.empty(chunk_size)
overlapping_chunks = torch.empty(chunk_size, device=hidden_states.device)
for chunk in range(chunk_size[1]):
overlapping_chunks[:, chunk, :, :] = hidden_states[
:, chunk * window_overlap : chunk * window_overlap + 2 * window_overlap, :
......
......@@ -796,7 +796,7 @@ class LongformerSelfAttention(nn.Module):
hidden_states.size(2),
]
overlapping_chunks = torch.empty(chunk_size)
overlapping_chunks = torch.empty(chunk_size, device=hidden_states.device)
for chunk in range(chunk_size[1]):
overlapping_chunks[:, chunk, :, :] = hidden_states[
:, chunk * window_overlap : chunk * window_overlap + 2 * window_overlap, :
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment