Unverified commit d6eeb871, authored by Karim Foda and committed by GitHub

Flax Remat for LongT5 (#17994)



* [Flax] Add remat (gradient checkpointing)

* fix variable naming in test

* flip: checkpoint using a method

* fix naming

* fix class naming

* apply PVP's suggestions from code review

* add gradient_checkpointing to examples

* Add gradient_checkpointing to run_mlm_flax

* Add remat to longt5

* Add gradient checkpointing test longt5

* Fix args errors

* Fix remaining tests

* Make fixup & quality fixes

* replace kwargs

* remove unnecessary kwargs

* Make fixup changes

* revert long_t5_flax changes

* Remove return_dict and copy to LongT5

* Remove test_gradient_checkpointing
Co-authored-by: sanchit-gandhi <sanchit@huggingface.co>
parent 1ccd2515
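Taken together, the commit threads a `gradient_checkpointing` flag from the Flax example scripts down into the T5 and LongT5 modules and exposes an `enable_gradient_checkpointing()` method on the pretrained model class (visible in the `FlaxT5PreTrainedModel` hunk below). A rough usage sketch; the checkpoint name, dtype and surrounding code are illustrative placeholders, not taken from the diff:

```python
# Rough usage sketch of the API added in this commit; "t5-small" and the dtype
# are illustrative placeholders.
import jax.numpy as jnp
from transformers import FlaxT5ForConditionalGeneration

model = FlaxT5ForConditionalGeneration.from_pretrained("t5-small", dtype=jnp.float32)

# Rebuild the underlying Flax module with gradient_checkpointing=True so each
# transformer block is rematerialized (recomputed) during the backward pass
# instead of keeping all intermediate activations in memory.
model.enable_gradient_checkpointing()
```

In the example scripts the same call is guarded by the new `gradient_checkpointing` training argument, as the first two hunks below show.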
@@ -107,6 +107,12 @@ class TrainingArguments:
         default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
     )
     hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
+    gradient_checkpointing: bool = field(
+        default=False,
+        metadata={
+            "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
+        },
+    )
 
     def __post_init__(self):
         if self.output_dir is not None:
@@ -640,6 +646,9 @@ def main():
         dtype=getattr(jnp, model_args.dtype),
     )
 
+    if training_args.gradient_checkpointing:
+        model.enable_gradient_checkpointing()
+
     # Store some constant
     num_epochs = int(training_args.num_train_epochs)
     train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
...
@@ -121,6 +121,12 @@ class TrainingArguments:
         default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
     )
     hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
+    gradient_checkpointing: bool = field(
+        default=False,
+        metadata={
+            "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
+        },
+    )
 
     def __post_init__(self):
         if self.output_dir is not None:
@@ -535,6 +541,9 @@ def main():
         dtype=getattr(jnp, model_args.dtype),
     )
 
+    if training_args.gradient_checkpointing:
+        model.enable_gradient_checkpointing()
+
     if model.config.decoder_start_token_id is None:
         raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
...
@@ -25,6 +25,7 @@ import jax
 import jax.numpy as jnp
 from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
 from flax.linen import combine_masks, make_causal_mask
+from flax.linen import partitioning as nn_partitioning
 from flax.linen.attention import dot_product_attention_weights
 from flax.traverse_util import flatten_dict, unflatten_dict
 from jax.random import PRNGKey
@@ -53,6 +54,8 @@ _CHECKPOINT_FOR_DOC = "google/long-t5-local-base"
 _CONFIG_FOR_DOC = "LongT5Config"
 _TOKENIZER_FOR_DOC = "T5Tokenizer"
 
+remat = nn_partitioning.remat
+
 
 # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
 def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
@@ -1356,7 +1359,6 @@ class FlaxLongT5LayerCollection(nn.Module):
         encoder_attention_mask=None,
         encoder_decoder_position_bias=None,
         output_attentions=False,
-        return_dict=True,
         deterministic=True,
         init_cache=False,
     ):
@@ -1377,11 +1379,29 @@ class FlaxLongT5BlockCollection(nn.Module):
 class FlaxLongT5BlockCollection(nn.Module):
     config: LongT5Config
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
 
     def setup(self):
         self.causal = self.config.causal
-        self.blocks = [
-            FlaxLongT5LayerCollection(self.config, has_relative_attention_bias=(i == 0), dtype=self.dtype, name=str(i))
-            for i in range(self.config.num_layers)
-        ]
+        if self.gradient_checkpointing:
+            FlaxLongT5CheckpointLayer = remat(FlaxLongT5LayerCollection, static_argnums=(6, 7, 8))
+            self.blocks = [
+                FlaxLongT5CheckpointLayer(
+                    self.config,
+                    has_relative_attention_bias=(i == 0),
+                    dtype=self.dtype,
+                    name=str(i),
+                )
+                for i in range(self.config.num_layers)
+            ]
+        else:
+            self.blocks = [
+                FlaxLongT5LayerCollection(
+                    self.config,
+                    has_relative_attention_bias=(i == 0),
+                    dtype=self.dtype,
+                    name=str(i),
+                )
+                for i in range(self.config.num_layers)
+            ]
@@ -1409,14 +1429,14 @@ class FlaxLongT5BlockCollection(nn.Module):
             layer_outputs = layer_module(
                 hidden_states,
-                attention_mask=attention_mask,
-                position_bias=position_bias,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_attention_mask,
-                encoder_decoder_position_bias=encoder_decoder_position_bias,
-                output_attentions=output_attentions,
-                deterministic=deterministic,
-                init_cache=init_cache,
+                attention_mask,
+                position_bias,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                encoder_decoder_position_bias,
+                output_attentions,
+                deterministic,
+                init_cache,
             )
 
             hidden_states = layer_outputs[0]
@@ -1447,11 +1467,14 @@ class FlaxLongT5Stack(nn.Module):
     config: LongT5Config
     embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
 
     def setup(self):
         self.causal = self.config.causal
-        self.block = FlaxLongT5BlockCollection(self.config, dtype=self.dtype)
+        self.block = FlaxLongT5BlockCollection(
+            self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+        )
         self.final_layer_norm = FlaxLongT5LayerNorm(
             self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
         )
@@ -1989,6 +2012,7 @@ LONGT5_START_DOCSTRING = r"""
 class FlaxLongT5Module(nn.Module):
     config: LongT5Config
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
 
     def _get_encoder_module(self):
         return self.encoder
@@ -2005,12 +2029,22 @@ class FlaxLongT5Module(nn.Module):
         encoder_config = copy.deepcopy(self.config)
         encoder_config.causal = False
-        self.encoder = FlaxLongT5Stack(encoder_config, embed_tokens=self.shared, dtype=self.dtype)
+        self.encoder = FlaxLongT5Stack(
+            encoder_config,
+            embed_tokens=self.shared,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
 
         decoder_config = copy.deepcopy(self.config)
         decoder_config.causal = True
         decoder_config.num_layers = self.config.num_decoder_layers
-        self.decoder = FlaxLongT5Stack(decoder_config, embed_tokens=self.shared, dtype=self.dtype)
+        self.decoder = FlaxLongT5Stack(
+            decoder_config,
+            embed_tokens=self.shared,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
 
     def __call__(
         self,
@@ -2104,6 +2138,7 @@ append_replace_return_docstrings(FlaxLongT5Model, output_type=FlaxSeq2SeqLMOutpu
 class FlaxLongT5ForConditionalGenerationModule(nn.Module):
     config: LongT5Config
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
 
     def _get_encoder_module(self):
         return self.encoder
@@ -2124,13 +2159,17 @@ class FlaxLongT5ForConditionalGenerationModule(nn.Module):
         encoder_config.causal = False
         encoder_config.use_cache = False
         encoder_config.is_encoder_decoder = False
-        self.encoder = FlaxLongT5Stack(encoder_config, self.shared, dtype=self.dtype)
+        self.encoder = FlaxLongT5Stack(
+            encoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+        )
 
         decoder_config = copy.deepcopy(self.config)
         decoder_config.causal = True
         decoder_config.is_encoder_decoder = False
         decoder_config.num_layers = self.config.num_decoder_layers
-        self.decoder = FlaxLongT5Stack(decoder_config, self.shared, dtype=self.dtype)
+        self.decoder = FlaxLongT5Stack(
+            decoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+        )
 
         self.lm_head = nn.Dense(
             self.config.vocab_size,
...
@@ -25,6 +25,7 @@ import jax
 import jax.numpy as jnp
 from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
 from flax.linen import combine_masks, make_causal_mask
+from flax.linen import partitioning as nn_partitioning
 from flax.linen.attention import dot_product_attention_weights
 from flax.traverse_util import flatten_dict, unflatten_dict
 from jax.random import PRNGKey
@@ -53,6 +54,8 @@ _CHECKPOINT_FOR_DOC = "t5-small"
 _CONFIG_FOR_DOC = "T5Config"
 _TOKENIZER_FOR_DOC = "T5Tokenizer"
 
+remat = nn_partitioning.remat
+
 
 # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
 def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
@@ -622,7 +625,6 @@ class FlaxT5LayerCollection(nn.Module):
         encoder_attention_mask=None,
         encoder_decoder_position_bias=None,
         output_attentions=False,
-        return_dict=True,
         deterministic=True,
         init_cache=False,
     ):
@@ -642,11 +644,29 @@ class FlaxT5BlockCollection(nn.Module):
 class FlaxT5BlockCollection(nn.Module):
     config: T5Config
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
 
     def setup(self):
         self.causal = self.config.causal
-        self.blocks = [
-            FlaxT5LayerCollection(self.config, has_relative_attention_bias=(i == 0), dtype=self.dtype, name=str(i))
-            for i in range(self.config.num_layers)
-        ]
+        if self.gradient_checkpointing:
+            FlaxT5CheckpointLayer = remat(FlaxT5LayerCollection, static_argnums=(6, 7, 8))
+            self.blocks = [
+                FlaxT5CheckpointLayer(
+                    self.config,
+                    has_relative_attention_bias=(i == 0),
+                    dtype=self.dtype,
+                    name=str(i),
+                )
+                for i in range(self.config.num_layers)
+            ]
+        else:
+            self.blocks = [
+                FlaxT5LayerCollection(
+                    self.config,
+                    has_relative_attention_bias=(i == 0),
+                    dtype=self.dtype,
+                    name=str(i),
+                )
+                for i in range(self.config.num_layers)
+            ]
@@ -674,14 +694,14 @@ class FlaxT5BlockCollection(nn.Module):
             layer_outputs = layer_module(
                 hidden_states,
-                attention_mask=attention_mask,
-                position_bias=position_bias,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_attention_mask,
-                encoder_decoder_position_bias=encoder_decoder_position_bias,
-                output_attentions=output_attentions,
-                deterministic=deterministic,
-                init_cache=init_cache,
+                attention_mask,
+                position_bias,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                encoder_decoder_position_bias,
+                output_attentions,
+                deterministic,
+                init_cache,
             )
 
             hidden_states = layer_outputs[0]
@@ -711,11 +731,14 @@ class FlaxT5Stack(nn.Module):
     config: T5Config
     embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
 
     def setup(self):
         self.causal = self.config.causal
-        self.block = FlaxT5BlockCollection(self.config, dtype=self.dtype)
+        self.block = FlaxT5BlockCollection(
+            self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+        )
         self.final_layer_norm = FlaxT5LayerNorm(
             self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
         )
@@ -919,11 +942,19 @@ class FlaxT5PreTrainedModel(FlaxPreTrainedModel):
         seed: int = 0,
         dtype: jnp.dtype = jnp.float32,
         _do_init: bool = True,
+        gradient_checkpointing: bool = False,
         **kwargs
     ):
-        module = self.module_class(config=config, dtype=dtype, **kwargs)
+        module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
         super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
 
+    def enable_gradient_checkpointing(self):
+        self._module = self.module_class(
+            config=self.config,
+            dtype=self.dtype,
+            gradient_checkpointing=True,
+        )
+
     def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
         # init input tensors
         input_ids = jnp.zeros(input_shape, dtype="i4")
@@ -1248,6 +1279,7 @@ T5_START_DOCSTRING = r"""
 class FlaxT5Module(nn.Module):
     config: T5Config
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
 
     def _get_encoder_module(self):
         return self.encoder
@@ -1264,12 +1296,22 @@ class FlaxT5Module(nn.Module):
         encoder_config = copy.deepcopy(self.config)
         encoder_config.causal = False
-        self.encoder = FlaxT5Stack(encoder_config, embed_tokens=self.shared, dtype=self.dtype)
+        self.encoder = FlaxT5Stack(
+            encoder_config,
+            embed_tokens=self.shared,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
 
         decoder_config = copy.deepcopy(self.config)
         decoder_config.causal = True
         decoder_config.num_layers = self.config.num_decoder_layers
-        self.decoder = FlaxT5Stack(decoder_config, embed_tokens=self.shared, dtype=self.dtype)
+        self.decoder = FlaxT5Stack(
+            decoder_config,
+            embed_tokens=self.shared,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
 
     def __call__(
         self,
@@ -1364,6 +1406,7 @@ append_replace_return_docstrings(FlaxT5Model, output_type=FlaxSeq2SeqLMOutput, c
 class FlaxT5EncoderModule(nn.Module):
     config: T5Config
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
 
     def setup(self):
         self.shared = nn.Embed(
@@ -1376,7 +1419,12 @@ class FlaxT5EncoderModule(nn.Module):
         encoder_config.is_decoder = False
         encoder_config.is_encoder_decoder = False
         encoder_config.causal = False
-        self.encoder = FlaxT5Stack(encoder_config, embed_tokens=self.shared, dtype=self.dtype)
+        self.encoder = FlaxT5Stack(
+            encoder_config,
+            embed_tokens=self.shared,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
 
     def __call__(
         self,
@@ -1384,7 +1432,7 @@ class FlaxT5EncoderModule(nn.Module):
         attention_mask=None,
         output_attentions=False,
         output_hidden_states=False,
-        return_dict=True,
+        return_dict: bool = True,
         deterministic: bool = True,
     ):
@@ -1445,6 +1493,7 @@ class FlaxT5EncoderModel(FlaxT5PreTrainedModel):
 class FlaxT5ForConditionalGenerationModule(nn.Module):
     config: T5Config
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
 
     def _get_encoder_module(self):
         return self.encoder
@@ -1465,13 +1514,17 @@ class FlaxT5ForConditionalGenerationModule(nn.Module):
         encoder_config.causal = False
         encoder_config.use_cache = False
         encoder_config.is_encoder_decoder = False
-        self.encoder = FlaxT5Stack(encoder_config, self.shared, dtype=self.dtype)
+        self.encoder = FlaxT5Stack(
+            encoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+        )
 
         decoder_config = copy.deepcopy(self.config)
         decoder_config.causal = True
         decoder_config.is_encoder_decoder = False
         decoder_config.num_layers = self.config.num_decoder_layers
-        self.decoder = FlaxT5Stack(decoder_config, self.shared, dtype=self.dtype)
+        self.decoder = FlaxT5Stack(
+            decoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+        )
 
         self.lm_head = nn.Dense(
             self.config.vocab_size,
...
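A note on the modeling changes above: `remat` (from `flax.linen.partitioning`) wraps the per-layer module so that its activations are recomputed during the backward pass, and `static_argnums=(6, 7, 8)` marks `output_attentions`, `deterministic` and `init_cache` as static, since they are Python bools that change the traced computation. Because `static_argnums` indexes positional arguments, the block collections now call the layers positionally instead of by keyword; removing the `return_dict` parameter is also what places those three arguments at positions 6, 7 and 8. The sketch below shows the same pattern on a toy module; it is an illustration under those assumptions, not code from this repository:

```python
# Standalone sketch of the remat pattern used above, on a toy module.
# Block, Stack and the shapes below are made up for illustration.
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen import partitioning as nn_partitioning

remat = nn_partitioning.remat


class Block(nn.Module):
    features: int

    @nn.compact
    def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
        # `deterministic` is a plain Python bool that alters the traced graph,
        # so it must be declared static when the module is rematerialized.
        hidden_states = nn.Dense(self.features)(hidden_states * attention_mask[..., None])
        if not deterministic:
            hidden_states = nn.Dropout(rate=0.1)(hidden_states, deterministic=False)
        return nn.relu(hidden_states)


class Stack(nn.Module):
    features: int
    num_layers: int
    gradient_checkpointing: bool = False

    def setup(self):
        # remat() returns a module class whose forward activations are recomputed
        # during the backward pass. static_argnums counts the positional __call__
        # arguments (excluding self): 0=hidden_states, 1=attention_mask,
        # 2=deterministic -- hence the positional call below.
        block_cls = remat(Block, static_argnums=(2,)) if self.gradient_checkpointing else Block
        self.blocks = [block_cls(self.features, name=str(i)) for i in range(self.num_layers)]

    def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
        for block in self.blocks:
            hidden_states = block(hidden_states, attention_mask, deterministic)
        return hidden_states


x = jnp.ones((2, 16, 32))
mask = jnp.ones((2, 16))
model = Stack(features=32, num_layers=4, gradient_checkpointing=True)
variables = model.init(jax.random.PRNGKey(0), x, mask)
loss_fn = lambda v: jnp.sum(model.apply(v, x, mask) ** 2)
grads = jax.grad(loss_fn)(variables)  # block activations are recomputed here
```

The trade-off is the usual one for gradient checkpointing: less activation memory in exchange for extra forward compute during the backward pass, which is exactly what the new `gradient_checkpointing` training argument opts into.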