Unverified commit 8f7078e8, authored by LiWenhan, committed by GitHub

Create the tensors in build_relative_position on the proper device instead of always on the CPU (#20434)
Co-authored-by: wenhanli <wenhanli@tencent.com>
parent de4159a3
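For context, torch.arange allocates on the CPU unless an explicit device argument is passed, so the relative-position ids used by these models were previously built on the CPU regardless of where the model lived. A minimal sketch of the behavior the patch relies on (the tensor names are illustrative, not taken from the diff):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
hidden_states = torch.randn(1, 8, 16, device=device)  # illustrative stand-in

# Before the patch: arange defaults to the CPU, so a transfer is needed later.
ids_cpu = torch.arange(0, hidden_states.size(-2))   # always lands on the CPU
ids_moved = ids_cpu.to(hidden_states.device)        # explicit host-to-device copy

# After the patch: the ids are created directly on the model's device.
ids_direct = torch.arange(0, hidden_states.size(-2), device=hidden_states.device)

assert ids_moved.device == ids_direct.device
assert torch.equal(ids_moved, ids_direct)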
@@ -487,7 +487,11 @@ class DebertaV2Encoder(nn.Module):
         if self.relative_attention and relative_pos is None:
             q = query_states.size(-2) if query_states is not None else hidden_states.size(-2)
             relative_pos = build_relative_position(
-                q, hidden_states.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions
+                q,
+                hidden_states.size(-2),
+                bucket_size=self.position_buckets,
+                max_position=self.max_relative_positions,
+                device=hidden_states.device,
             )
         return relative_pos
@@ -589,7 +593,7 @@ def make_log_bucket_position(relative_pos, bucket_size, max_position):
     return bucket_pos


-def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1):
+def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1, device=None):
     """
     Build relative position according to the query and key

@@ -602,13 +606,14 @@ def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-
         key_size (int): the length of key
         bucket_size (int): the size of position bucket
         max_position (int): the maximum allowed absolute position
+        device (`torch.device`): the device on which tensors will be created.

     Return:
         `torch.LongTensor`: A tensor with shape [1, query_size, key_size]
     """

-    q_ids = torch.arange(0, query_size)
-    k_ids = torch.arange(0, key_size)
+    q_ids = torch.arange(0, query_size, device=device)
+    k_ids = torch.arange(0, key_size, device=device)
     rel_pos_ids = q_ids[:, None] - k_ids[None, :]
     if bucket_size > 0 and max_position > 0:
         rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
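The visible part of the patched helper can be exercised on its own. Below is a minimal standalone sketch: the bucketing branch is stubbed out as a comment, and the final reshape to the documented [1, query_size, key_size] shape is an assumption based on the docstring, since the tail of the function falls outside this hunk:

import torch

def build_relative_position_sketch(query_size, key_size, device=None):
    # Pairwise (query index - key index), created directly on `device`.
    q_ids = torch.arange(0, query_size, device=device)
    k_ids = torch.arange(0, key_size, device=device)
    rel_pos_ids = q_ids[:, None] - k_ids[None, :]
    # The real helper optionally log-bucketizes the ids here via
    # make_log_bucket_position(rel_pos_ids, bucket_size, max_position).
    # Assumed from the docstring: add a leading dim -> [1, query_size, key_size].
    return rel_pos_ids.to(torch.long).unsqueeze(0)

rel = build_relative_position_sketch(4, 6)
print(rel.shape)   # torch.Size([1, 4, 6])
print(rel[0, 0])   # tensor([ 0, -1, -2, -3, -4, -5])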
@@ -778,7 +783,11 @@ class DisentangledSelfAttention(nn.Module):
         if relative_pos is None:
             q = query_layer.size(-2)
             relative_pos = build_relative_position(
-                q, key_layer.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions
+                q,
+                key_layer.size(-2),
+                bucket_size=self.position_buckets,
+                max_position=self.max_relative_positions,
+                device=query_layer.device,
             )
         if relative_pos.dim() == 2:
             relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
@@ -835,7 +844,8 @@ class DisentangledSelfAttention(nn.Module):
                 key_layer.size(-2),
                 bucket_size=self.position_buckets,
                 max_position=self.max_relative_positions,
-            ).to(query_layer.device)
+                device=query_layer.device,
+            )
             r_pos = r_pos.unsqueeze(0)
         else:
             r_pos = relative_pos
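The hunk just above differs from the other call sites: the old code was already device-correct because it moved the result with .to(query_layer.device), but everything inside build_relative_position (the aranges, the subtraction, any bucketing) still ran on the CPU before that single copy. Passing device= instead performs the whole construction on the target device. A sketch of the two patterns (function and variable names are illustrative):

import torch

def create_then_move(n, target):
    ids = torch.arange(0, n)            # intermediate lives on the CPU
    rel = ids[:, None] - ids[None, :]   # arithmetic also runs on the CPU
    return rel.to(target)               # only the final tensor is copied over

def create_on_device(n, target):
    ids = torch.arange(0, n, device=target)  # allocated on the target device
    return ids[:, None] - ids[None, :]       # arithmetic runs there as well

target = "cuda" if torch.cuda.is_available() else "cpu"
assert torch.equal(create_then_move(8, target), create_on_device(8, target))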
@@ -209,7 +209,7 @@ def make_log_bucket_position(relative_pos, bucket_size, max_position):
 # Copied from transformers.models.deberta_v2.modeling_deberta_v2.build_relative_position
-def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1):
+def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1, device=None):
     """
     Build relative position according to the query and key
@@ -222,13 +222,14 @@ def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-
         key_size (int): the length of key
         bucket_size (int): the size of position bucket
         max_position (int): the maximum allowed absolute position
+        device (`torch.device`): the device on which tensors will be created.

     Return:
         `torch.LongTensor`: A tensor with shape [1, query_size, key_size]
     """

-    q_ids = torch.arange(0, query_size)
-    k_ids = torch.arange(0, key_size)
+    q_ids = torch.arange(0, query_size, device=device)
+    k_ids = torch.arange(0, key_size, device=device)
     rel_pos_ids = q_ids[:, None] - k_ids[None, :]
     if bucket_size > 0 and max_position > 0:
         rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
@@ -827,7 +828,11 @@ class DisentangledSelfAttention(nn.Module):
         if relative_pos is None:
             q = query_layer.size(-2)
             relative_pos = build_relative_position(
-                q, key_layer.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions
+                q,
+                key_layer.size(-2),
+                bucket_size=self.position_buckets,
+                max_position=self.max_relative_positions,
+                device=query_layer.device,
             )
         if relative_pos.dim() == 2:
             relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
@@ -884,7 +889,8 @@ class DisentangledSelfAttention(nn.Module):
                 key_layer.size(-2),
                 bucket_size=self.position_buckets,
                 max_position=self.max_relative_positions,
-            ).to(query_layer.device)
+                device=query_layer.device,
+            )
             r_pos = r_pos.unsqueeze(0)
         else:
             r_pos = relative_pos
@@ -1093,7 +1099,11 @@ class SEWDTransformerEncoder(nn.Module):
         if self.relative_attention and relative_pos is None:
             q = query_states.size(-2) if query_states is not None else hidden_states.size(-2)
             relative_pos = build_relative_position(
-                q, hidden_states.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions
+                q,
+                hidden_states.size(-2),
+                bucket_size=self.position_buckets,
+                max_position=self.max_relative_positions,
+                device=hidden_states.device,
             )
         return relative_pos
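A change like this can be smoke-tested end to end by running a forward pass on an accelerator. A hedged sketch (the checkpoint name is an assumption; any DeBERTa-v2-style model with relative attention enabled would exercise build_relative_position):

import torch
from transformers import AutoModel, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
name = "microsoft/deberta-v3-base"  # assumed checkpoint; uses the DeBERTa-v2 code path
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModel.from_pretrained(name).to(device).eval()

inputs = tokenizer("a quick smoke test", return_tensors="pt").to(device)
with torch.no_grad():
    out = model(**inputs)

# With the patch, the relative-position ids are built on `device`, so the
# forward pass completes without implicit CPU tensors in the attention math.
assert out.last_hidden_state.device == next(model.parameters()).device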