Commit 50fe58fa authored by wxj

Update transformer.py

parent 52610942
Pipeline #2151 passed
@@ -582,13 +582,13 @@ class ParallelAttention(MegatronModule):
         else:
             kv_projection_size = args.kv_channels * args.num_attention_heads
-        self.use_flash_attn = (args.use_flash_attn_ck or args.use_flash_attn_triton) \
+        self.use_flash_attn = (args.use_flash_attn_cutlass or args.use_flash_attn_triton) \
             and attention_type == AttnType.self_attn \
             and self.attn_mask_type == AttnMaskType.causal
         self.use_flash_attn_triton = args.use_flash_attn_triton
         if self.use_flash_attn:
-            if args.use_flash_attn_ck:
+            if args.use_flash_attn_cutlass:
                 if flash_attn_unpadded_func is None:
                     raise ImportError('FlashAttention is not installed, please install with '
                                       'pip install flash-attn')
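
The commit renames the CK backend flag to use_flash_attn_cutlass in ParallelAttention's backend selection. Below is a minimal standalone sketch of that selection logic, assuming argparse-style boolean flags; only the flag names visible in the diff are taken from the change, everything else is illustrative.

# Minimal sketch of the backend-selection logic touched by this commit.
# Assumption: the flags are argparse booleans; names other than
# use_flash_attn_cutlass / use_flash_attn_triton are illustrative.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--use-flash-attn-cutlass', action='store_true')
parser.add_argument('--use-flash-attn-triton', action='store_true')
args = parser.parse_args(['--use-flash-attn-triton'])

# FlashAttention is enabled when either backend flag is set; the real code
# additionally requires causal self-attention (AttnType / AttnMaskType checks).
use_flash_attn = args.use_flash_attn_cutlass or args.use_flash_attn_triton
use_flash_attn_triton = args.use_flash_attn_triton

if use_flash_attn and args.use_flash_attn_cutlass:
    # The CUTLASS path depends on flash_attn_unpadded_func from flash-attn;
    # the diff raises ImportError if that import is unavailable.
    pass

print(use_flash_attn, use_flash_attn_triton)  # True True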