Unverified Commit a3345c1f authored by Peter, committed by GitHub

Add `accelerate` support for LongT5 models (#20341)

* add `accelerate` support for LongT5 models

Signed-off-by: peter szemraj <peterszemraj@gmail.com>

* fix `accelerate` tests

* Trigger CI test

Signed-off-by: peter szemraj <peterszemraj@gmail.com>
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
parent 8286af6f
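With this change LongT5 checkpoints can be loaded through accelerate's big-model dispatch. A minimal usage sketch (the checkpoint name, prompt, and generation settings are illustrative, not part of this commit, and it assumes the model fits on the detected devices):

```python
from transformers import AutoTokenizer, LongT5ForConditionalGeneration

# device_map="auto" asks accelerate to spread the weights over the
# available devices instead of loading everything onto a single one.
model = LongT5ForConditionalGeneration.from_pretrained(
    "google/long-t5-tglobal-base", device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("google/long-t5-tglobal-base")

inputs = tokenizer(
    "summarize: studies have shown that owning a dog is good for you.",
    return_tensors="pt",
)
outputs = model.generate(**inputs.to(model.device), max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```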
@@ -648,9 +648,12 @@ class LongT5LocalAttention(nn.Module):
     def compute_bias(self, block_length: int):
         """Compute binned relative position bias"""
-        memory_position = torch.arange(
-            3 * block_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
+        target_device = (
+            self.relative_attention_bias.weight.device
+            if self.relative_attention_bias.weight.device.type != "meta"
+            else None
         )
+        memory_position = torch.arange(3 * block_length, dtype=torch.long, device=target_device)
         context_position = memory_position[block_length:-block_length]
         # (block_length, 3 * block_length)
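The `target_device` guard above exists because accelerate first materializes parameters on PyTorch's meta device (shape and dtype only, no data), so the position indices must be created on a real device instead. A standalone sketch of that behavior (not part of the diff):

```python
import torch

# Parameters created under accelerate's init_empty_weights() live on the
# "meta" device: they carry shape and dtype but no storage.
weight = torch.empty(16, 16, device="meta")

# Mirroring the compute_bias logic: fall back to the default device when
# the weight is still on meta, otherwise follow the weight's device.
target_device = weight.device if weight.device.type != "meta" else None
memory_position = torch.arange(12, dtype=torch.long, device=target_device)
print(memory_position.device)  # cpu (the default device)
```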
@@ -843,9 +846,12 @@ class LongT5TransientGlobalAttention(nn.Module):
     def compute_bias(self, block_length: int):
         """Compute binned relative position bias"""
-        memory_position = torch.arange(
-            3 * block_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
+        target_device = (
+            self.relative_attention_bias.weight.device
+            if self.relative_attention_bias.weight.device.type != "meta"
+            else None
         )
+        memory_position = torch.arange(3 * block_length, dtype=torch.long, device=target_device)
         context_position = memory_position[block_length:-block_length]
         # (block_length, 3 * block_length)
@@ -1271,6 +1277,7 @@ class LongT5PreTrainedModel(PreTrainedModel):
     config_class = LongT5Config
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
+    _no_split_modules = ["LongT5Block"]

     @property
     # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel.dummy_inputs
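`_no_split_modules` feeds accelerate's device-map inference: a `LongT5Block` must stay whole on one device rather than be sharded in the middle of its attention and feed-forward weights. A sketch of how that constraint is consumed (checkpoint name and memory budget are illustrative):

```python
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import AutoConfig, LongT5ForConditionalGeneration

config = AutoConfig.from_pretrained("google/long-t5-tglobal-base")
with init_empty_weights():
    model = LongT5ForConditionalGeneration(config)

# _no_split_modules is forwarded as no_split_module_classes, so every
# LongT5Block is assigned to a single device in the resulting map.
device_map = infer_auto_device_map(
    model,
    max_memory={0: "2GiB", "cpu": "8GiB"},
    no_split_module_classes=["LongT5Block"],
)
print(device_map)
```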
@@ -1366,7 +1373,9 @@ class LongT5Stack(LongT5PreTrainedModel):
     def __init__(self, config, embed_tokens=None):
         super().__init__(config)

-        self.embed_tokens = embed_tokens
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
+        if embed_tokens is not None:
+            self.embed_tokens.weight = embed_tokens.weight
         self.is_decoder = config.is_decoder

         self.local_radius = config.local_radius
...
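`LongT5Stack` now always owns an `nn.Embedding` and ties its weight to the shared embedding passed in from the parent model; having a real submodule (rather than a borrowed reference) gives accelerate something concrete to place. A standalone sketch of the weight tying (sizes illustrative):

```python
import torch.nn as nn

shared = nn.Embedding(32128, 512)       # e.g. the model-level shared embedding
stack_embed = nn.Embedding(32128, 512)  # the stack's own embedding module

# Assigning the Parameter ties the weights: both modules share one tensor.
stack_embed.weight = shared.weight
assert stack_embed.weight.data_ptr() == shared.weight.data_ptr()
```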