Unverified Commit 767e6b53 authored by Wang, Yi's avatar Wang, Yi Committed by GitHub
Browse files

fix gptj could not jit.trace in GPU (#23317)


Signed-off-by: default avatarWang, Yi A <yi.a.wang@intel.com>
parent b4698b7e
......@@ -212,7 +212,7 @@ class GPTJAttention(nn.Module):
key = self._split_heads(key, self.num_attention_heads, self.head_dim, True)
value = self._split_heads(value, self.num_attention_heads, self.head_dim, False)
if is_torch_fx_proxy(position_ids):
if is_torch_fx_proxy(position_ids) or torch.jit.is_tracing():
# The logic to conditionally copy to GPU could not be traced, so we do this
# every time in the torch.fx case
embed_positions = get_embed_positions(self.embed_positions, position_ids)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment