Unverified Commit 9818f85f authored by Tri Dao's avatar Tri Dao Committed by GitHub
Browse files

Merge pull request #255 from beginlner/main

Fix a bug
parents 85b51d61 8e44c0ee
......@@ -119,7 +119,7 @@ class Block(nn.Module):
before applying the query projection. Useful for e.g., ViT where we only care
about the CLS token in the last layer.
"""
fused_add_norm_fn = (dropout_add_rms_norm if isinstance(self.norm1, RMSNorm)
fused_add_norm_fn = (dropout_add_rms_norm if RMSNorm and isinstance(self.norm1, RMSNorm)
else dropout_add_layer_norm)
if self.prenorm:
if not self.fused_dropout_add_ln:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment