Unverified commit 9818f85f, authored by Tri Dao and committed by GitHub

Merge pull request #255 from beginlner/main

Fix a bug
parents 85b51d61 8e44c0ee
@@ -119,7 +119,7 @@ class Block(nn.Module):
                 before applying the query projection. Useful for e.g., ViT where we only care
                 about the CLS token in the last layer.
         """
-        fused_add_norm_fn = (dropout_add_rms_norm if isinstance(self.norm1, RMSNorm)
+        fused_add_norm_fn = (dropout_add_rms_norm if RMSNorm and isinstance(self.norm1, RMSNorm)
                              else dropout_add_layer_norm)
         if self.prenorm:
             if not self.fused_dropout_add_ln:
......
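
The commit message does not say what the bug was. One plausible reading of the change above, assuming `RMSNorm` is bound to `None` when an optional import is unavailable (the source does not confirm this), is that `isinstance(x, None)` raises a `TypeError`, and the added `RMSNorm and ...` guard short-circuits before `isinstance` is reached. A minimal sketch with hypothetical stand-in names, not the library's actual code:

```python
# Hypothetical stand-ins for illustration only. The names mirror the diff,
# but the classes, functions, and the None fallback are assumptions.
class LayerNorm:
    """Placeholder for the norm module held in self.norm1."""


def dropout_add_layer_norm():
    """Placeholder for the fused LayerNorm function."""


# Assumed state after a failed optional import: both names bound to None.
RMSNorm = None
dropout_add_rms_norm = None


def pick_unguarded(norm):
    # Pre-fix expression: isinstance() is always evaluated, and it raises
    # TypeError when its second argument is None.
    return (dropout_add_rms_norm if isinstance(norm, RMSNorm)
            else dropout_add_layer_norm)


def pick_guarded(norm):
    # Post-fix expression: `RMSNorm and ...` short-circuits when RMSNorm
    # is None, so isinstance() is never called with an invalid argument.
    return (dropout_add_rms_norm if RMSNorm and isinstance(norm, RMSNorm)
            else dropout_add_layer_norm)


norm = LayerNorm()
try:
    pick_unguarded(norm)
except TypeError as exc:
    print("unguarded check failed:", type(exc).__name__)

print("guarded check selected:", pick_guarded(norm).__name__)
```

Under that assumption, the guarded form falls back to the LayerNorm path whenever `RMSNorm` is unavailable, and behaves the same as before when it is a real class.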