"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "2bd950ca47e1687d524c0fab1a7b83b10e55f458"
Unverified commit 767e6b53, authored by Wang, Yi; committed by GitHub
Browse files

fix gptj could not jit.trace in GPU (#23317)


Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
parent b4698b7e
@@ -212,7 +212,7 @@ class GPTJAttention(nn.Module):
         key = self._split_heads(key, self.num_attention_heads, self.head_dim, True)
         value = self._split_heads(value, self.num_attention_heads, self.head_dim, False)
-        if is_torch_fx_proxy(position_ids):
+        if is_torch_fx_proxy(position_ids) or torch.jit.is_tracing():
             # The logic to conditionally copy to GPU could not be traced, so we do this
             # every time in the torch.fx case
             embed_positions = get_embed_positions(self.embed_positions, position_ids)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.