"git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "5a8fed914f96c05f299b46142fed5b76ba6e6db4"
Commit b3177dfa authored by Tri Dao

[GPT] Enable FlashAttention for GPT-J

parent 6fc1e07d
@@ -276,6 +276,9 @@ class ParallelBlock(nn.Module):
             for p in self.norm2.parameters():
                 p._shared_params = True
 
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
+
     def forward(self, hidden_states1: Tensor, hidden_states2: Optional[Tensor] = None,
                 residual: Optional[Tensor] = None, mixer_kwargs=None):
         r"""Pass the input through the encoder layer.
...
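The new method exists so that cached generation works for parallel-block models such as GPT-J: before decoding, the model asks every layer to pre-allocate its key/value cache, and ParallelBlock simply forwards that request to its attention mixer. A minimal standalone sketch of that delegation pattern (the helper name is hypothetical, not code from this repo):

```python
# Hypothetical helper showing the delegation: the model collects one
# KV-cache entry per layer; each block defers to self.mixer, which owns
# the cache layout for its attention implementation.
def allocate_model_inference_cache(layers, batch_size, max_seqlen, dtype=None, **kwargs):
    return {
        i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
        for i, layer in enumerate(layers)
    }
```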
@@ -36,7 +36,7 @@ def test_gptj_optimized(model_name):
     dtype = torch.float16
     device = 'cuda'
     config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
-    config.use_flash_attn = False  # FlashAttention doesn't support hdim 256 yet
+    config.use_flash_attn = True  # FlashAttention-2 supports headdim 256
     config.fused_bias_fc = True
     config.fused_mlp = True
     config.fused_dropout_add_ln = True
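Outside the test, the same flags can be set when loading GPT-J through the repo's GPT wrapper. A hedged sketch, assuming the test-style loading path where GPTLMHeadModel.from_pretrained takes (model_name, config, device=..., dtype=...); exact arguments may differ:

```python
# Sketch under the assumptions noted above; "EleutherAI/gpt-j-6B" is the
# usual Hugging Face checkpoint name for GPT-J.
import torch
from transformers import GPTJConfig

from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gptj import gptj_config_to_gpt2_config

model_name = "EleutherAI/gpt-j-6B"
config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
config.use_flash_attn = True        # head dim 256 is supported by FlashAttention-2
config.fused_bias_fc = True
config.fused_mlp = True
config.fused_dropout_add_ln = True

model = GPTLMHeadModel.from_pretrained(model_name, config, device="cuda", dtype=torch.float16)
model.eval()
```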
@@ -93,7 +93,7 @@ def test_gptj_generation(model_name):
     dtype = torch.float16
     device = 'cuda'
     config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
-    config.use_flash_attn = False  # FlashAttention doesn't support hdim 256 yet
+    config.use_flash_attn = True  # FlashAttention-2 supports headdim 256
     config.fused_bias_fc = True
     config.fused_mlp = True
     config.fused_dropout_add_ln = True
...
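With use_flash_attn enabled and ParallelBlock able to allocate its inference cache, the generation test exercises cached decoding end to end. A hedged usage sketch, reusing `model` from the previous snippet and assuming generate() accepts HF-style input_ids/max_length arguments:

```python
# Sketch only: the generate() signature and return type are assumptions
# about flash-attn's GenerationMixin, not verified against this commit.
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
input_ids = tokenizer("Hello, my dog is cute and", return_tensors="pt").input_ids.to("cuda")

with torch.inference_mode():
    # Per-layer KV caches are pre-allocated (via allocate_inference_cache)
    # before decoding, so each step attends over cached keys/values only.
    out = model.generate(input_ids=input_ids, max_length=40)

sequences = out.sequences if hasattr(out, "sequences") else out
print(tokenizer.batch_decode(sequences))
```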