Unverified Commit 8f7078e8 authored by LiWenhan and committed by GitHub

make tensors in function build_relative_position created on proper device instead of always on cpu (#20434)

Co-authored-by: wenhanli <wenhanli@tencent.com>
parent de4159a3
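The patch below swaps the create-on-CPU-then-move pattern for direct allocation with `torch.arange(..., device=...)`. A minimal standalone sketch of the two patterns (the function names here are illustrative, not taken from the repository):

```python
import torch


def build_rel_ids_on_cpu(query_size: int, key_size: int, device: torch.device) -> torch.Tensor:
    # Old pattern: index tensors are allocated on the CPU and only moved afterwards,
    # costing an extra host-to-device copy (or a device-mismatch error if the caller
    # forgets the .to() entirely).
    q_ids = torch.arange(0, query_size)
    k_ids = torch.arange(0, key_size)
    return (q_ids[:, None] - k_ids[None, :]).to(device)


def build_rel_ids_on_device(query_size: int, key_size: int, device: torch.device) -> torch.Tensor:
    # Pattern introduced by this commit: torch.arange accepts a `device` argument,
    # so the index tensors are allocated directly where the model's activations live.
    q_ids = torch.arange(0, query_size, device=device)
    k_ids = torch.arange(0, key_size, device=device)
    return q_ids[:, None] - k_ids[None, :]
```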
@@ -487,7 +487,11 @@ class DebertaV2Encoder(nn.Module):
         if self.relative_attention and relative_pos is None:
             q = query_states.size(-2) if query_states is not None else hidden_states.size(-2)
             relative_pos = build_relative_position(
-                q, hidden_states.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions
+                q,
+                hidden_states.size(-2),
+                bucket_size=self.position_buckets,
+                max_position=self.max_relative_positions,
+                device=hidden_states.device,
             )
         return relative_pos
@@ -589,7 +593,7 @@ def make_log_bucket_position(relative_pos, bucket_size, max_position):
     return bucket_pos


-def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1):
+def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1, device=None):
     """
     Build relative position according to the query and key
@@ -602,13 +606,14 @@ def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1):
         key_size (int): the length of key
         bucket_size (int): the size of position bucket
         max_position (int): the maximum allowed absolute position
+        device (`torch.device`): the device on which tensors will be created.

     Return:
         `torch.LongTensor`: A tensor with shape [1, query_size, key_size]

     """
-    q_ids = torch.arange(0, query_size)
-    k_ids = torch.arange(0, key_size)
+    q_ids = torch.arange(0, query_size, device=device)
+    k_ids = torch.arange(0, key_size, device=device)
     rel_pos_ids = q_ids[:, None] - k_ids[None, :]
     if bucket_size > 0 and max_position > 0:
         rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
@@ -778,7 +783,11 @@ class DisentangledSelfAttention(nn.Module):
         if relative_pos is None:
             q = query_layer.size(-2)
             relative_pos = build_relative_position(
-                q, key_layer.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions
+                q,
+                key_layer.size(-2),
+                bucket_size=self.position_buckets,
+                max_position=self.max_relative_positions,
+                device=query_layer.device,
             )
         if relative_pos.dim() == 2:
             relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
@@ -835,7 +844,8 @@ class DisentangledSelfAttention(nn.Module):
                     key_layer.size(-2),
                     bucket_size=self.position_buckets,
                     max_position=self.max_relative_positions,
-                ).to(query_layer.device)
+                    device=query_layer.device,
+                )
                 r_pos = r_pos.unsqueeze(0)
             else:
                 r_pos = relative_pos
...
@@ -209,7 +209,7 @@ def make_log_bucket_position(relative_pos, bucket_size, max_position):
 # Copied from transformers.models.deberta_v2.modeling_deberta_v2.build_relative_position
-def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1):
+def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1, device=None):
     """
     Build relative position according to the query and key
@@ -222,13 +222,14 @@ def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1):
         key_size (int): the length of key
         bucket_size (int): the size of position bucket
         max_position (int): the maximum allowed absolute position
+        device (`torch.device`): the device on which tensors will be created.

     Return:
         `torch.LongTensor`: A tensor with shape [1, query_size, key_size]

     """
-    q_ids = torch.arange(0, query_size)
-    k_ids = torch.arange(0, key_size)
+    q_ids = torch.arange(0, query_size, device=device)
+    k_ids = torch.arange(0, key_size, device=device)
     rel_pos_ids = q_ids[:, None] - k_ids[None, :]
     if bucket_size > 0 and max_position > 0:
         rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
@@ -827,7 +828,11 @@ class DisentangledSelfAttention(nn.Module):
         if relative_pos is None:
             q = query_layer.size(-2)
             relative_pos = build_relative_position(
-                q, key_layer.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions
+                q,
+                key_layer.size(-2),
+                bucket_size=self.position_buckets,
+                max_position=self.max_relative_positions,
+                device=query_layer.device,
             )
         if relative_pos.dim() == 2:
             relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
@@ -884,7 +889,8 @@ class DisentangledSelfAttention(nn.Module):
                     key_layer.size(-2),
                     bucket_size=self.position_buckets,
                     max_position=self.max_relative_positions,
-                ).to(query_layer.device)
+                    device=query_layer.device,
+                )
                 r_pos = r_pos.unsqueeze(0)
             else:
                 r_pos = relative_pos
@@ -1093,7 +1099,11 @@ class SEWDTransformerEncoder(nn.Module):
         if self.relative_attention and relative_pos is None:
             q = query_states.size(-2) if query_states is not None else hidden_states.size(-2)
             relative_pos = build_relative_position(
-                q, hidden_states.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions
+                q,
+                hidden_states.size(-2),
+                bucket_size=self.position_buckets,
+                max_position=self.max_relative_positions,
+                device=hidden_states.device,
             )
         return relative_pos
...
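A short usage sketch of the updated signature, assuming the function is imported from the module named in the `# Copied from` comment above; the expected shape and dtype follow the docstring shown in the diff:

```python
import torch
from transformers.models.deberta_v2.modeling_deberta_v2 import build_relative_position

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# With the new keyword, the relative-position tensor is created directly on `device`,
# so no CPU -> GPU copy is needed downstream.
rel_pos = build_relative_position(
    query_size=8,
    key_size=8,
    bucket_size=-1,   # -1 leaves log-bucketing disabled, matching the defaults in the diff
    max_position=-1,
    device=device,
)

print(rel_pos.shape)   # torch.Size([1, 8, 8]) -- [1, query_size, key_size] per the docstring
print(rel_pos.device)  # same as `device`
```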