Commit 0e8c46ae authored by Tri Dao

Run isort and black on test files

parent 7fcd3e6a
@@ -2,26 +2,24 @@
import math

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange
from flash_attn.layers.rotary import RotaryEmbedding, apply_rotary_emb_func, apply_rotary_emb_qkv_
from transformers.models.gpt_neox.modeling_gpt_neox import RotaryEmbedding as RotaryEmbeddingNeoX
from transformers.models.gpt_neox.modeling_gpt_neox import (
    apply_rotary_pos_emb as apply_rotary_pos_emb_neox,
)
from transformers.models.gptj.modeling_gptj import apply_rotary_pos_emb as apply_rotary_pos_emb_gptj
from transformers.models.gptj.modeling_gptj import fixed_pos_embedding


# NeoX-style rotary embedding
@pytest.mark.parametrize("seqlen_offset", [0, 711])
@pytest.mark.parametrize("rotary_emb_fraction", [0.5, 1.0])
def test_rotary(rotary_emb_fraction, seqlen_offset):
    device = "cuda"
    dtype = torch.float16
    rtol, atol = (1e-3, 5e-3)
    # set seed
@@ -32,49 +30,70 @@ def test_rotary(rotary_emb_fraction, seqlen_offset):
    nheads = 16
    headdim = 128
    rotary_dim = int(headdim * rotary_emb_fraction)
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype, requires_grad=True
    )
    qkv_og = qkv.clone().detach()  # Our implementation modifies qkv inplace
    rotary = RotaryEmbedding(rotary_dim, device=device)
    rotary_neox = RotaryEmbeddingNeoX(rotary_dim, seqlen_total, device=device)
    # Doesn't matter what tensor we pass in, rotary_neox only uses the device of the tensor
    cos_neox, sin_neox = rotary_neox(qkv, seq_len=seqlen_total)
    cos_neox, sin_neox = cos_neox.to(dtype=dtype), sin_neox.to(dtype=dtype)
    q_pt = (
        rearrange(qkv[:, :, 0, :, :rotary_dim], "b s h d -> b h s d")
        .detach()
        .clone()
        .requires_grad_(True)
    )
    k_pt = (
        rearrange(qkv[:, :, 1, :, :rotary_dim], "b s h d -> b h s d")
        .detach()
        .clone()
        .requires_grad_(True)
    )
    q_neox, k_neox = apply_rotary_pos_emb_neox(q_pt, k_pt, cos_neox, sin_neox, offset=seqlen_offset)
    out = rotary(qkv, seqlen_offset=seqlen_offset)
    assert torch.allclose(
        rotary._cos_cached, cos_neox[..., : rotary_dim // 2].to(dtype=dtype), rtol=rtol, atol=atol
    )
    assert torch.allclose(
        rotary._sin_cached, sin_neox[..., : rotary_dim // 2].to(dtype=dtype), rtol=rtol, atol=atol
    )
    assert torch.allclose(
        rearrange(q_neox, "b h s d -> b s h d"), out[:, :, 0, :, :rotary_dim], rtol=rtol, atol=atol
    )
    assert torch.allclose(
        rearrange(k_neox, "b h s d -> b s h d"), out[:, :, 1, :, :rotary_dim], rtol=rtol, atol=atol
    )
    assert torch.equal(out[:, :, 0:2, :, rotary_dim:], qkv_og[:, :, 0:2, :, rotary_dim:])
    assert torch.equal(out[:, :, 2], qkv_og[:, :, 2])
    g = torch.randn_like(out)
    g_og = g.clone().detach()  # Our implementation modifies g inplace
    out.backward(g)
    q_neox.backward(rearrange(g_og[:, :, 0, :, :rotary_dim], "b s h d -> b h s d"))
    k_neox.backward(rearrange(g_og[:, :, 1, :, :rotary_dim], "b s h d -> b h s d"))
    assert torch.allclose(
        rearrange(q_pt.grad, "b h s d -> b s h d"),
        qkv.grad[:, :, 0, :, :rotary_dim],
        rtol=rtol,
        atol=atol,
    )
    assert torch.allclose(
        rearrange(k_pt.grad, "b h s d -> b s h d"),
        qkv.grad[:, :, 1, :, :rotary_dim],
        rtol=rtol,
        atol=atol,
    )
    assert torch.equal(qkv.grad[:, :, 0:2, :, rotary_dim:], g_og[:, :, 0:2, :, rotary_dim:])
    assert torch.equal(qkv.grad[:, :, 2], g_og[:, :, 2])


# GPT-J-style rotary embedding
@pytest.mark.parametrize("seqlen_offset", [0, 711])
@pytest.mark.parametrize("rotary_emb_fraction", [0.5, 1.0])
def test_rotary_interleaved(rotary_emb_fraction, seqlen_offset):
    device = "cuda"
    dtype = torch.float16
    rtol, atol = (1e-3, 5e-3)
    # set seed
@@ -85,8 +104,9 @@ def test_rotary_interleaved(rotary_emb_fraction, seqlen_offset):
    nheads = 16
    headdim = 128
    rotary_dim = int(headdim * rotary_emb_fraction)
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype, requires_grad=True
    )
    qkv_og = qkv.clone().detach()  # Our implementation modifies qkv inplace
    rotary = RotaryEmbedding(rotary_dim, interleaved=True, device=device)
    sincos_gptj = fixed_pos_embedding(qkv[..., :rotary_dim], seq_dim=1, seq_len=seqlen_total)
...
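The rotary tests above check the fused, in-place flash_attn implementation against the HF references. For orientation, here is a minimal sketch of the NeoX-style (rotate-half) update that both implementations are expected to compute on the first rotary_dim channels; the function names are illustrative and not part of this commit.

import torch


def rotate_half(x):
    # Split the rotary channels in half and rotate: (x1, x2) -> (-x2, x1)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_reference(x, cos, sin):
    # x: (batch, nheads, seqlen, rotary_dim); cos/sin broadcast over the leading dims.
    # NeoX-style rotary update: x * cos + rotate_half(x) * sin
    return x * cos + rotate_half(x) * sin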
import math

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange
from flash_attn.losses.cross_entropy import CrossEntropyLossApex

is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8


@pytest.mark.parametrize(
    "dtype", [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else [])
)
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("inplace_backward", [False, True])
# @pytest.mark.parametrize('inplace_backward', [False])
@pytest.mark.parametrize("smoothing", [0.0, 0.9])
@pytest.mark.parametrize("vocab_size", [50257])
def test_cross_entropy_loss_apex(vocab_size, smoothing, inplace_backward, dtype):
    device = "cuda"
    rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 128
    x_pt = torch.randn(
        batch_size * seqlen, vocab_size, device=device, dtype=dtype, requires_grad=True
    )
    x = x_pt.detach().clone().requires_grad_()
    y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device)
    y[torch.randperm(batch_size * seqlen)[:10]] = -100
...
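The test above parametrizes label smoothing (0.0 and 0.9) and marks some targets as ignored with -100. As a reference for what the smoothed per-token loss computes, a minimal sketch (ignore_index handling omitted; names here are illustrative, not flash_attn APIs):

import torch
import torch.nn.functional as F


def smoothed_cross_entropy(logits, target, smoothing=0.0):
    # Per-token label-smoothed cross entropy (reduction="none"); -100 targets are not handled here.
    # With smoothing=0.0 this reduces to the usual negative log-likelihood.
    logprobs = F.log_softmax(logits.float(), dim=-1)
    nll = -logprobs.gather(-1, target.unsqueeze(-1)).squeeze(-1)
    smooth_loss = -logprobs.mean(dim=-1)
    return (1.0 - smoothing) * nll + smoothing * smooth_loss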
@@ -3,35 +3,37 @@
import math

import pytest
import torch
import torch.nn.functional as F
from apex.transformer import parallel_state, tensor_parallel
from flash_attn.losses.cross_entropy import CrossEntropyLoss

is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8


@pytest.mark.parametrize(
    "dtype", [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else [])
)
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("inplace_backward", [False, True])
# @pytest.mark.parametrize('inplace_backward', [False])
@pytest.mark.parametrize("smoothing", [0.0, 0.9])
# @pytest.mark.parametrize('smoothing', [0.9])
@pytest.mark.parametrize("vocab_size", [50264])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
def test_cross_entropy_loss_parallel(vocab_size, world_size, smoothing, inplace_backward, dtype):
    assert vocab_size % world_size == 0
    rtol, atol = (
        (1e-5, 1e-6)
        if dtype == torch.float32
        else ((1e-3, 1e-4) if dtype == torch.float16 else (1e-2, 3e-3))
    )
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    partition_vocab_size = vocab_size // world_size
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
@@ -39,15 +41,24 @@ def test_cross_entropy_loss_parallel(vocab_size, world_size, smoothing, inplace_
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 128
    x_pt = (
        torch.randn(batch_size * seqlen, vocab_size, device=device, dtype=dtype) * 10
    ).requires_grad_()
    x = (
        tensor_parallel.scatter_to_tensor_model_parallel_region(x_pt)
        .detach()
        .clone()
        .requires_grad_()
    )
    y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device)
    y[torch.randperm(batch_size * seqlen)[:10]] = -100
    model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing, reduction="none")
    model = CrossEntropyLoss(
        label_smoothing=smoothing,
        reduction="none",
        inplace_backward=inplace_backward,
        process_group=parallel_state.get_tensor_model_parallel_group(),
    )
    out = model(x, y)
    out_pt = model_pt(x_pt.float(), y)
    assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)
@@ -55,6 +66,11 @@ def test_cross_entropy_loss_parallel(vocab_size, world_size, smoothing, inplace_
    g = torch.randn_like(out)
    out_pt.backward(g)
    out.backward(g)
    assert torch.allclose(
        x.grad,
        x_pt.grad[:, (rank * partition_vocab_size) : (rank + 1) * partition_vocab_size],
        rtol=rtol,
        atol=atol,
    )
    parallel_state.destroy_model_parallel()
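In the parallel test above, each rank only holds a contiguous slice of the vocabulary dimension, which is why x.grad is compared against that slice of the full-vocab gradient. A minimal sketch of the partitioning (the helper name is illustrative, not an apex or flash_attn API):

import torch


def vocab_partition(logits, rank, world_size):
    # Tensor-parallel cross entropy shards the last (vocab) dimension:
    # rank r holds columns [r * V_p, (r + 1) * V_p) with V_p = vocab_size // world_size.
    vocab_size = logits.shape[-1]
    assert vocab_size % world_size == 0
    partition_vocab_size = vocab_size // world_size
    return logits[..., rank * partition_vocab_size : (rank + 1) * partition_vocab_size]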
import re
from collections import OrderedDict

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange
from flash_attn.models.bert import BertForPreTraining, BertModel, remap_state_dict
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertForPreTraining as BertForPreTrainingHF
from transformers.models.bert.modeling_bert import BertModel as BertModelHF


@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_state_dict(model_name):
    config = BertConfig.from_pretrained(model_name)
@@ -30,12 +26,15 @@ def test_bert_state_dict(model_name):
def get_hf_models(model_name, config, dtype):
    pretrained_state_dict = state_dict_from_pretrained(model_name)

    def key_mapping_ln_gamma_beta(key):
        key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
        key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key)
        return key

    pretrained_state_dict = OrderedDict(
        (key_mapping_ln_gamma_beta(k), v) for k, v in pretrained_state_dict.items()
    )
    model_hf = BertForPreTrainingHF(config)
    # Missing key(s) in state_dict: "bert.embeddings.position_ids", "cls.predictions.decoder.bias"
    # position_ids is a buffer, and predictions.decoder.bias is tied to predictions.bias.

@@ -44,7 +43,7 @@ def get_hf_models(model_name, config, dtype):
    return model_hf


@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_non_optimized(model_name):
    """Check that our implementation of BERT (without any optimizations enabled) matches the
@@ -67,10 +66,11 @@ def test_bert_non_optimized(model_name):
    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
    attention_mask = torch.arange(max_seqlen, device="cuda")[None, :] < seqlens[:, None]
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
    )
    out = model.bert(input_ids, attention_mask=attention_mask)
    sequence_output, pooled_output = out.last_hidden_state, out.pooler_output
    out_hf = model_hf.bert(input_ids, attention_mask=attention_mask)
@@ -78,15 +78,19 @@ def test_bert_non_optimized(model_name):
    out_ref = model_ref.bert(input_ids, attention_mask=attention_mask)
    sequence_output_ref, pooled_output_ref = out_ref.last_hidden_state, out_ref.pooler_output
    print(f"Output max diff: {(sequence_output - sequence_output_ref).abs().max().item()}")
    print(f"Output mean diff: {(sequence_output - sequence_output_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(sequence_output_hf - sequence_output_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(sequence_output_hf - sequence_output_ref).abs().mean().item()}")
    assert (sequence_output - sequence_output_ref).abs().max().item() < 3 * (
        sequence_output_hf - sequence_output_ref
    ).abs().max().item()
    assert (pooled_output - pooled_output_ref).abs().max().item() < 3 * (
        pooled_output_hf - pooled_output_ref
    ).abs().max().item()


@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_optimized(model_name):
    """Check that our implementation of BERT (with all optimizations enabled) matches the
@@ -117,10 +121,11 @@ def test_bert_optimized(model_name):
    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
    attention_mask = torch.arange(max_seqlen, device="cuda")[None, :] < seqlens[:, None]
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
    )
    out = model.bert(input_ids, attention_mask=attention_mask)
    sequence_output, pooled_output = out.last_hidden_state, out.pooler_output
    out_hf = model_hf.bert(input_ids, attention_mask=attention_mask)
@@ -131,12 +136,24 @@ def test_bert_optimized(model_name):
    sequence_output_ref, pooled_output_ref = out_ref.last_hidden_state, out_ref.pooler_output
    sequence_output_ref[~attention_mask, :] = 0.0
    print(
        f"BertModel output max diff: {(sequence_output - sequence_output_ref).abs().max().item()}"
    )
    print(
        f"BertModel output mean diff: {(sequence_output - sequence_output_ref).abs().mean().item()}"
    )
    print(
        f"HF fp16 BertModel max diff: {(sequence_output_hf - sequence_output_ref).abs().max().item()}"
    )
    print(
        f"HF fp16 BertModel mean diff: {(sequence_output_hf - sequence_output_ref).abs().mean().item()}"
    )
    assert (sequence_output - sequence_output_ref).abs().max().item() < 4 * (
        sequence_output_hf - sequence_output_ref
    ).abs().max().item()
    assert (pooled_output - pooled_output_ref).abs().max().item() < 4 * (
        pooled_output_hf - pooled_output_ref
    ).abs().max().item()

    out = model(input_ids, attention_mask=attention_mask)
    prediction_scores, seq_relationship_scores = out.prediction_logits, out.seq_relationship_logits
@@ -144,25 +161,43 @@ def test_bert_optimized(model_name):
    prediction_scores = prediction_scores.clone()
    prediction_scores[~attention_mask, :] = 0.0
    out_hf = model_hf(input_ids, attention_mask=attention_mask)
    prediction_scores_hf, seq_relationship_scores_hf = (
        out_hf.prediction_logits,
        out_hf.seq_relationship_logits,
    )
    prediction_scores_hf[~attention_mask, :] = 0.0
    out_ref = model_ref(input_ids, attention_mask=attention_mask)
    prediction_scores_ref, seq_relationship_scores_ref = (
        out_ref.prediction_logits,
        out_ref.seq_relationship_logits,
    )
    prediction_scores_ref[~attention_mask, :] = 0.0
    print(
        f"prediction_scores max diff: {(prediction_scores - prediction_scores_ref).abs().max().item()}"
    )
    print(
        f"prediction_scores mean diff: {(prediction_scores - prediction_scores_ref).abs().mean().item()}"
    )
    print(
        f"HF fp16 prediction_scores max diff: {(prediction_scores_hf - prediction_scores_ref).abs().max().item()}"
    )
    print(
        f"HF fp16 prediction_scores mean diff: {(prediction_scores_hf - prediction_scores_ref).abs().mean().item()}"
    )
    assert (prediction_scores - prediction_scores_ref).abs().max().item() < 2 * (
        prediction_scores_hf - prediction_scores_ref
    ).abs().max().item()
    assert (seq_relationship_scores - seq_relationship_scores_ref).abs().max().item() < 2 * (
        seq_relationship_scores_hf - seq_relationship_scores_ref
    ).abs().max().item()


@pytest.mark.parametrize("last_layer_subset", [False, True])
# @pytest.mark.parametrize('last_layer_subset', [True])
@pytest.mark.parametrize("has_key_padding_mask", [True, False])
# @pytest.mark.parametrize('has_key_padding_mask', [True])
@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_dense_seq_output(model_name, has_key_padding_mask, last_layer_subset):
    """Check that our implementation of BERT (with all optimizations enabled) matches the
@@ -196,40 +231,70 @@ def test_bert_dense_seq_output(model_name, has_key_padding_mask, last_layer_subs
    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
    if has_key_padding_mask:
        attention_mask = torch.arange(max_seqlen, device="cuda")[None, :] < seqlens[:, None]
    else:
        attention_mask = None
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
    )
    labels = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
    )
    if attention_mask is not None:
        labels[~attention_mask] = 0
    labels[(torch.rand(batch_size, max_seqlen, device="cuda") > 0.15)] = 0
    masked_tokens_mask = labels.flatten() > 0
    next_sequence_label = torch.randint(0, 2, (batch_size,), device="cuda")
    out = model(
        input_ids,
        attention_mask=attention_mask,
        labels=labels,
        next_sentence_label=next_sequence_label,
    )
    prediction_scores, seq_relationship_scores = out.prediction_logits, out.seq_relationship_logits
    out_hf = model_hf(
        input_ids,
        attention_mask=attention_mask,
        labels=labels,
        next_sentence_label=next_sequence_label,
    )
    prediction_scores_hf, seq_relationship_scores_hf = (
        out_hf.prediction_logits,
        out_hf.seq_relationship_logits,
    )
    prediction_scores_hf = rearrange(prediction_scores_hf, "b s d -> (b s) d")[masked_tokens_mask]
    out_ref = model_ref(
        input_ids,
        attention_mask=attention_mask,
        labels=labels,
        next_sentence_label=next_sequence_label,
    )
    prediction_scores_ref, seq_relationship_scores_ref = (
        out_ref.prediction_logits,
        out_ref.seq_relationship_logits,
    )
    prediction_scores_ref = rearrange(prediction_scores_ref, "b s d -> (b s) d")[masked_tokens_mask]
    print(
        f"prediction_scores max diff: {(prediction_scores - prediction_scores_ref).abs().max().item()}"
    )
    print(
        f"prediction_scores mean diff: {(prediction_scores - prediction_scores_ref).abs().mean().item()}"
    )
    print(
        f"HF fp16 prediction_scores max diff: {(prediction_scores_hf - prediction_scores_ref).abs().max().item()}"
    )
    print(
        f"HF fp16 prediction_scores mean diff: {(prediction_scores_hf - prediction_scores_ref).abs().mean().item()}"
    )
    assert (prediction_scores - prediction_scores_ref).abs().max().item() < 2 * (
        prediction_scores_hf - prediction_scores_ref
    ).abs().max().item()
    assert (seq_relationship_scores - seq_relationship_scores_ref).abs().max().item() < 2 * (
        seq_relationship_scores_hf - seq_relationship_scores_ref
    ).abs().max().item()
    # The loss calculation from HF is wrong: it doesn't ignore the labels that are 0.
    # assert (out.loss - out_ref.loss).abs().max().item() < 2 * (out_hf.loss - out_ref.loss).abs().max().item()
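The dense-seq-output path only scores the positions selected for masked-LM prediction (labels > 0), so the test flattens the HF logits and indexes them with masked_tokens_mask before comparing. A minimal restatement of that selection (the helper name is illustrative):

import torch
from einops import rearrange


def gather_masked_positions(prediction_scores, labels):
    # Flatten batch and sequence, then keep only positions with a masked-LM label (> 0),
    # matching how the test slices the HF and reference logits above.
    masked_tokens_mask = labels.flatten() > 0
    flat_scores = rearrange(prediction_scores, "b s d -> (b s) d")
    return flat_scores[masked_tokens_mask]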
@@ -3,44 +3,46 @@
import os
import time
from pathlib import Path

current_dir = Path(__file__).parent.absolute()

import pytest
import torch
from einops import rearrange
from flash_attn.models.falcon import falcon_config_to_gpt2_config, remap_state_dict_hf_falcon
from flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp, shard_state_dict_tp
from flash_attn.utils.distributed import all_gather_raw
from flash_attn.utils.generation import update_graph_cache
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer


@pytest.mark.parametrize("model_name", ["tiiuae/falcon-7b", "tiiuae/falcon-40b"])
def test_falcon_state_dict(model_name):
    config = falcon_config_to_gpt2_config(
        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    )
    pretrained_state_dict = remap_state_dict_hf_falcon(
        state_dict_from_pretrained(model_name), config
    )
    model = GPTLMHeadModel(config, device="meta")  # Without device='meta' init is very slow
    state_dict = model.state_dict()
    assert state_dict.keys() == pretrained_state_dict.keys()
    for k in state_dict.keys():
        assert state_dict[k].shape == pretrained_state_dict[k].shape


@pytest.mark.parametrize("model_name", ["tiiuae/falcon-7b"])
def test_falcon_optimized(model_name):
    """Check that our implementation (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    device = "cuda"
    config = falcon_config_to_gpt2_config(
        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    )
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused MLP for "gelu" activation
@@ -53,8 +55,9 @@ def test_falcon_optimized(model_name):
    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
    )
    with torch.no_grad():
        out = model.transformer(input_ids)
        logits = model(input_ids).logits

@@ -78,30 +81,33 @@ def test_falcon_optimized(model_name):
        logits_hf = model_hf(input_ids).logits
    del model_hf

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()

    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
    print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
    assert (logits - logits_ref).abs().max().item() < 3 * (
        logits_hf - logits_ref
    ).abs().max().item()


# torchrun --no_python --nproc_per_node=4 pytest -q -s tests/models/test_falcon.py -k "falcon_parallel_forward"
# We want to run this on a machine with 4 x A100 80GB or 8 x A100 40GB so we have enough
# memory to run the model in fp32.
@pytest.mark.parametrize("world_size", [4])
@pytest.mark.parametrize("model_name", ["tiiuae/falcon-40b"])
def test_falcon_parallel_forward(model_name, world_size):
    from apex.transformer import parallel_state

    dtype = torch.float16
    config = falcon_config_to_gpt2_config(
        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    )
    config.use_flash_attn = False
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused MLP for "gelu" activation

@@ -109,14 +115,16 @@ def test_falcon_parallel_forward(model_name, world_size):
    config.residual_in_fp32 = True

    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()

    pretrained_state_dict = remap_state_dict_hf_falcon(
        state_dict_from_pretrained(model_name), config
    )
    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))

@@ -126,8 +134,9 @@ def test_falcon_parallel_forward(model_name, world_size):
    batch_size = 2
    max_seqlen = 256
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
    )
    with torch.no_grad():
        out = model.transformer(input_ids)
        out, _ = all_gather_raw(out, process_group=process_group)

@@ -135,7 +144,7 @@ def test_falcon_parallel_forward(model_name, world_size):
        logits = model(input_ids).logits
        logits = rearrange(logits, "(b s) d -> b s d", b=batch_size)
        logits, _ = all_gather_raw(logits, process_group)
        logits = rearrange(logits, "(n b) ... d -> b ... (n d)", b=batch_size)
    del model

    if rank == 0:
@@ -157,29 +166,32 @@ def test_falcon_parallel_forward(model_name, world_size):
            logits_ref = model_ref(input_ids).logits.to(device=device)
        del model_ref

        print(f"Output max diff: {(out - out_ref).abs().max().item()}")
        print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
        print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
        print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
        assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()

        print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
        print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
        print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
        print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
        assert (logits - logits_ref).abs().max().item() < 2 * (
            logits_hf - logits_ref
        ).abs().max().item()
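In test_falcon_parallel_forward, each rank produces logits for its shard of the vocabulary; all_gather_raw stacks those shards along the leading dimension, and the rearrange pattern "(n b) ... d -> b ... (n d)" regroups them onto the last dimension. A minimal restatement of that reassembly step (the function name is illustrative, not a flash_attn API):

import torch
from einops import rearrange


def reassemble_tp_logits(gathered, batch_size):
    # gathered: (world_size * batch_size, seqlen, vocab_size // world_size),
    # i.e. the per-rank logit shards stacked along dim 0 by the all-gather.
    # Regroup so the vocabulary shards are concatenated on the last dimension.
    return rearrange(gathered, "(n b) ... d -> b ... (n d)", b=batch_size)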
@pytest.mark.parametrize("model_name", ["tiiuae/falcon-7b"])
def test_falcon_generation(model_name):
    """Check that our implementation (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    device = "cuda"
    config = falcon_config_to_gpt2_config(
        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    )
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused MLP for "gelu" activation

@@ -193,8 +205,9 @@ def test_falcon_generation(model_name):
    batch_size = 1
    seqlen = 100
    max_length = 150
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
    )
    model_hf = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=dtype, device_map={"": device}, trust_remote_code=True

@@ -203,10 +216,11 @@ def test_falcon_generation(model_name):
    print("HF fp16")
    torch.cuda.synchronize()
    start = time.time()
    out_hf = model_hf.generate(
        input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    del model_hf

    model_ref = AutoModelForCausalLM.from_pretrained(

@@ -214,37 +228,49 @@ def test_falcon_generation(model_name):
    )
    model_ref.eval()
    with torch.no_grad():
        logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]
    del model_ref

    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model.eval()

    print("Without CUDA graph")
    torch.cuda.synchronize()
    start = time.time()
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        eos_token_id=eos_token_id,
        fused_ft_kernel=True,
        return_dict_in_generate=True,
        output_scores=True,
        timing=True,
        teacher_outputs=out_hf.sequences,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")

    # Capture graph outside the timing loop
    batch_size, seqlen_og = input_ids.shape
    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
    print("With CUDA graph")
    torch.cuda.synchronize()
    start = time.time()
    out_cg = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        fused_ft_kernel=True,
        cg=True,
        return_dict_in_generate=True,
        output_scores=True,
        timing=True,
        teacher_outputs=out_hf.sequences,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")

    with torch.no_grad():
        logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1]
    logits_hf = torch.stack(out_hf.scores, dim=1)
    logits = torch.stack(out.scores, dim=1)
    logits_cg = torch.stack(out_cg.scores, dim=1)
@@ -254,18 +280,18 @@ def test_falcon_generation(model_name):
    hf_error = (logits_hf - logits_ref).abs().max().item()
    assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error

    print(f"HF fp16 logits max diff: {hf_error}")
    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
    assert (logits - logits_ref).abs().max().item() < 2 * hf_error
    print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}")
    assert torch.equal(logits_cg, logits)


# torchrun --no_python --nproc_per_node=4 pytest -q -s tests/models/test_falcon.py -k "falcon_parallel_generation"
# We want to run this on a machine with 4 x A100 80GB or 8 x A100 40GB so we have enough
# memory to run the model in fp32.
@pytest.mark.parametrize('world_size', [4]) @pytest.mark.parametrize("world_size", [4])
@pytest.mark.parametrize('model_name', ["tiiuae/falcon-40b"]) @pytest.mark.parametrize("model_name", ["tiiuae/falcon-40b"])
def test_falcon_parallel_generation(model_name, world_size): def test_falcon_parallel_generation(model_name, world_size):
"""Check that our implementation matches the HF implementation: """Check that our implementation matches the HF implementation:
the scores in fp16 should be around the same as the HF scores in fp16, when compared to the scores in fp16 should be around the same as the HF scores in fp16, when compared to
...@@ -274,8 +300,9 @@ def test_falcon_parallel_generation(model_name, world_size): ...@@ -274,8 +300,9 @@ def test_falcon_parallel_generation(model_name, world_size):
from apex.transformer import parallel_state from apex.transformer import parallel_state
dtype = torch.float16 dtype = torch.float16
config = falcon_config_to_gpt2_config(AutoConfig.from_pretrained(model_name, config = falcon_config_to_gpt2_config(
trust_remote_code=True)) AutoConfig.from_pretrained(model_name, trust_remote_code=True)
)
config.use_flash_attn = False config.use_flash_attn = False
config.fused_bias_fc = True config.fused_bias_fc = True
config.fused_mlp = False # We don't have fused MLP for "gelu" activation config.fused_mlp = False # We don't have fused MLP for "gelu" activation
...@@ -286,8 +313,8 @@ def test_falcon_parallel_generation(model_name, world_size): ...@@ -286,8 +313,8 @@ def test_falcon_parallel_generation(model_name, world_size):
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0" os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
if not torch.distributed.is_initialized(): if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend='nccl', init_method='env://') torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f'cuda:{torch.distributed.get_rank()}' device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size() assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size) parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank() rank = parallel_state.get_tensor_model_parallel_rank()
...@@ -297,36 +324,50 @@ def test_falcon_parallel_generation(model_name, world_size): ...@@ -297,36 +324,50 @@ def test_falcon_parallel_generation(model_name, world_size):
batch_size = 1 batch_size = 1
seqlen = 100 seqlen = 100
max_length = 150 max_length = 150
input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, input_ids = torch.randint(
device=device) 0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
)
# Need this, otherwise when we capture the graph the process for GPU 1 would run on both # Need this, otherwise when we capture the graph the process for GPU 1 would run on both
# GPU0 and GPU1 and things would hang # GPU0 and GPU1 and things would hang
torch.cuda.set_device(device) torch.cuda.set_device(device)
pretrained_state_dict = remap_state_dict_hf_falcon(state_dict_from_pretrained(model_name), config) pretrained_state_dict = remap_state_dict_hf_falcon(
state_dict_from_pretrained(model_name), config
)
model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype) model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank)) model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
model.eval() model.eval()
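# Editorial aside, not part of the commit: shard_state_dict_tp above splits the full
# checkpoint so that each tensor-parallel rank only loads its own slice of every
# sharded weight. A hypothetical, simplified sketch of the idea for a single
# column-parallel weight (the real helper also covers row-parallel layers,
# embeddings, and biases):
def _shard_column_parallel(weight, world_size, rank):
    # weight: (out_features, in_features); each rank keeps a contiguous slice
    # of the output dimension
    out_per_rank = weight.shape[0] // world_size
    return weight[rank * out_per_rank : (rank + 1) * out_per_rank]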
print('Without CUDA graph') print("Without CUDA graph")
out = model.generate( out = model.generate(
input_ids=input_ids, max_length=max_length, tensor_parallel=world_size, input_ids=input_ids,
vocab_size=config.vocab_size, fused_ft_kernel=True, max_length=max_length,
tensor_parallel=world_size,
vocab_size=config.vocab_size,
fused_ft_kernel=True,
# teacher_outputs=out_hf.sequences, # teacher_outputs=out_hf.sequences,
return_dict_in_generate=True, output_scores=True, timing=True return_dict_in_generate=True,
output_scores=True,
timing=True,
) )
# Capture graph outside the timing loop # Capture graph outside the timing loop
batch_size, seqlen_og = input_ids.shape batch_size, seqlen_og = input_ids.shape
model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length) model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
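# Editorial aside, not part of the commit: update_graph_cache pre-captures the
# single-token decoding step as a CUDA graph, so each generated token replays a fixed
# kernel sequence instead of re-launching kernels one by one. A hypothetical,
# simplified sketch of the bare PyTorch capture/replay pattern (the real cache also
# handles warmup, variable batch sizes, and the KV cache):
def _capture_decode_step(model, batch_size, device):
    # CUDA graphs replay on fixed memory addresses, so inputs are copied into
    # static buffers before every replay.
    static_input = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_logits = model(static_input).logits

    def run(token_ids):
        static_input.copy_(token_ids)
        graph.replay()
        return static_logits

    return run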
print('With CUDA graph') print("With CUDA graph")
out_cg = model.generate( out_cg = model.generate(
input_ids=input_ids, max_length=max_length, tensor_parallel=world_size, input_ids=input_ids,
vocab_size=config.vocab_size, fused_ft_kernel=True, cg=True, max_length=max_length,
tensor_parallel=world_size,
vocab_size=config.vocab_size,
fused_ft_kernel=True,
cg=True,
# teacher_outputs=out_hf.sequences, # teacher_outputs=out_hf.sequences,
return_dict_in_generate=True, output_scores=True, timing=True return_dict_in_generate=True,
output_scores=True,
timing=True,
) )
del model del model
parallel_state.destroy_model_parallel() parallel_state.destroy_model_parallel()
@@ -341,11 +382,13 @@ def test_falcon_parallel_generation(model_name, world_size):
start = time.time() start = time.time()
with torch.inference_mode(): with torch.inference_mode():
out_hf = model_hf.generate( out_hf = model_hf.generate(
input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, input_ids=input_ids,
output_scores=True max_length=max_length,
return_dict_in_generate=True,
output_scores=True,
) )
torch.cuda.synchronize() torch.cuda.synchronize()
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms') print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
del model_hf del model_hf
model_ref = AutoModelForCausalLM.from_pretrained( model_ref = AutoModelForCausalLM.from_pretrained(
@@ -353,7 +396,7 @@ def test_falcon_parallel_generation(model_name, world_size):
) )
model_ref.eval() model_ref.eval()
with torch.inference_mode(): with torch.inference_mode():
logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1] logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]
del model_ref del model_ref
logits_hf = torch.stack(out_hf.scores, dim=1) logits_hf = torch.stack(out_hf.scores, dim=1)
@@ -361,8 +404,8 @@ def test_falcon_parallel_generation(model_name, world_size):
logits_cg = torch.stack(out_cg.scores, dim=1) logits_cg = torch.stack(out_cg.scores, dim=1)
hf_error = (logits_hf - logits_ref).abs().max().item() hf_error = (logits_hf - logits_ref).abs().max().item()
print(f'HF fp16 logits max diff: {hf_error}') print(f"HF fp16 logits max diff: {hf_error}")
print(f'Logits max diff: {(logits - logits_ref).abs().max().item() }') print(f"Logits max diff: {(logits - logits_ref).abs().max().item() }")
assert (logits - logits_ref).abs().max().item() < 2 * hf_error assert (logits - logits_ref).abs().max().item() < 2 * hf_error
print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }') print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }")
assert torch.equal(logits_cg, logits) assert torch.equal(logits_cg, logits)
import re

import pytest
import torch
from flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_gpt2
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF
@pytest.mark.parametrize('model_name', ["gpt2", "gpt2-medium"]) @pytest.mark.parametrize("model_name", ["gpt2", "gpt2-medium"])
# @pytest.mark.parametrize('model_name', ["gpt2"]) # @pytest.mark.parametrize('model_name', ["gpt2"])
def test_gpt2_state_dict(model_name): def test_gpt2_state_dict(model_name):
config = GPT2Config.from_pretrained(model_name) config = GPT2Config.from_pretrained(model_name)
@@ -23,7 +20,7 @@ def test_gpt2_state_dict(model_name):
assert state_dict[k].shape == pretrained_state_dict[k].shape assert state_dict[k].shape == pretrained_state_dict[k].shape
@pytest.mark.parametrize('model_name', ["gpt2", "gpt2-medium"]) @pytest.mark.parametrize("model_name", ["gpt2", "gpt2-medium"])
# @pytest.mark.parametrize('model_name', ["gpt2"]) # @pytest.mark.parametrize('model_name', ["gpt2"])
def test_gpt2_non_optimized(model_name): def test_gpt2_non_optimized(model_name):
"""Check that our implementation of GPT2 (without any optimizations enabled) matches the """Check that our implementation of GPT2 (without any optimizations enabled) matches the
@@ -46,31 +43,34 @@ def test_gpt2_non_optimized(model_name):
torch.manual_seed(0) torch.manual_seed(0)
batch_size = 4 batch_size = 4
max_seqlen = 512 max_seqlen = 512
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda') seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, input_ids = torch.randint(
device='cuda') 0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
)
out = model.transformer(input_ids) out = model.transformer(input_ids)
out_hf = model_hf.transformer(input_ids).last_hidden_state out_hf = model_hf.transformer(input_ids).last_hidden_state
out_ref = model_ref.transformer(input_ids).last_hidden_state out_ref = model_ref.transformer(input_ids).last_hidden_state
print(f'Output max diff: {(out - out_ref).abs().max().item()}') print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}') print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}') print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item() assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
logits = model(input_ids).logits logits = model(input_ids).logits
logits_hf = model_hf(input_ids).logits logits_hf = model_hf(input_ids).logits
logits_ref = model_ref(input_ids).logits logits_ref = model_ref(input_ids).logits
print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}') print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}') print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}') print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}') print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item() assert (logits - logits_ref).abs().max().item() < 3 * (
logits_hf - logits_ref
).abs().max().item()
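# Editorial note: the acceptance criterion above is relative rather than absolute. The
# fp16 output of this implementation may differ from the fp32 reference by at most 3x
# the difference that HF's own fp16 model shows against the same reference. With
# hypothetical numbers: if (logits_hf - logits_ref).abs().max() were 2e-3, then
# (logits - logits_ref).abs().max() would have to stay below 6e-3.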
@pytest.mark.parametrize('model_name', ["gpt2", "gpt2-medium"]) @pytest.mark.parametrize("model_name", ["gpt2", "gpt2-medium"])
# @pytest.mark.parametrize('model_name', ["gpt2"]) # @pytest.mark.parametrize('model_name', ["gpt2"])
def test_gpt2_optimized(model_name): def test_gpt2_optimized(model_name):
"""Check that our implementation of GPT2 (with all optimizations enabled) matches the """Check that our implementation of GPT2 (with all optimizations enabled) matches the
@@ -100,25 +100,28 @@ def test_gpt2_optimized(model_name):
torch.manual_seed(0) torch.manual_seed(0)
batch_size = 4 batch_size = 4
max_seqlen = 512 max_seqlen = 512
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda') seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
input_ids = torch.randint(0, vocab_size_og, (batch_size, max_seqlen), dtype=torch.long, input_ids = torch.randint(
device='cuda') 0, vocab_size_og, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
)
out = model.transformer(input_ids) out = model.transformer(input_ids)
out_hf = model_hf.transformer(input_ids).last_hidden_state out_hf = model_hf.transformer(input_ids).last_hidden_state
out_ref = model_ref.transformer(input_ids).last_hidden_state out_ref = model_ref.transformer(input_ids).last_hidden_state
print(f'Output max diff: {(out - out_ref).abs().max().item()}') print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}') print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}') print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item() assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
logits = model(input_ids).logits[..., :vocab_size_og] logits = model(input_ids).logits[..., :vocab_size_og]
logits_hf = model_hf(input_ids).logits logits_hf = model_hf(input_ids).logits
logits_ref = model_ref(input_ids).logits logits_ref = model_ref(input_ids).logits
print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}') print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}') print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}') print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}') print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item() assert (logits - logits_ref).abs().max().item() < 3 * (
logits_hf - logits_ref
).abs().max().item()
@@ -2,36 +2,32 @@ import os
import re
import time

import pytest
import torch
from einops import rearrange
from flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_gpt2
from flash_attn.models.opt import opt_config_to_gpt2_config, remap_state_dict_hf_opt
from flash_attn.utils.generation import update_graph_cache
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import AutoTokenizer, GPT2Config, GPT2Tokenizer, OPTConfig
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF
from transformers.models.opt.modeling_opt import OPTForCausalLM
@pytest.mark.parametrize('fused_ft_kernel', [False, True]) @pytest.mark.parametrize("fused_ft_kernel", [False, True])
# @pytest.mark.parametrize('fused_ft_kernel', [True]) # @pytest.mark.parametrize('fused_ft_kernel', [True])
@pytest.mark.parametrize('optimized', [False, True]) @pytest.mark.parametrize("optimized", [False, True])
# @pytest.mark.parametrize('optimized', [False]) # @pytest.mark.parametrize('optimized', [False])
@pytest.mark.parametrize('rotary', [False, True]) @pytest.mark.parametrize("rotary", [False, True])
# @pytest.mark.parametrize('rotary', [False]) # @pytest.mark.parametrize('rotary', [False])
@pytest.mark.parametrize('model_name', ["gpt2"]) @pytest.mark.parametrize("model_name", ["gpt2"])
def test_greedy_decode_gpt2(model_name, rotary, optimized, fused_ft_kernel): def test_greedy_decode_gpt2(model_name, rotary, optimized, fused_ft_kernel):
"""Check that our implementation of GPT2 generation matches the HF implementation: """Check that our implementation of GPT2 generation matches the HF implementation:
the scores in fp16 should be around the same as the HF scores in fp16, when compared to the scores in fp16 should be around the same as the HF scores in fp16, when compared to
the HF scores in fp32. the HF scores in fp32.
""" """
dtype = torch.float16 dtype = torch.float16
device = 'cuda' device = "cuda"
rtol, atol = 3e-3, 3e-1 rtol, atol = 3e-3, 3e-1
config = GPT2Config.from_pretrained(model_name) config = GPT2Config.from_pretrained(model_name)
if rotary: if rotary:
@@ -47,21 +43,24 @@ def test_greedy_decode_gpt2(model_name, rotary, optimized, fused_ft_kernel):
# if not rotary, we load the weight from HF but ignore the position embeddings. # if not rotary, we load the weight from HF but ignore the position embeddings.
# The model would be nonsense but it doesn't matter for the test. # The model would be nonsense but it doesn't matter for the test.
model = GPTLMHeadModel.from_pretrained(model_name, config, strict=not rotary, device=device, model = GPTLMHeadModel.from_pretrained(
dtype=dtype) model_name, config, strict=not rotary, device=device, dtype=dtype
)
model.eval() model.eval()
if not rotary: if not rotary:
model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device) model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device)
model_hf = GPT2LMHeadModelHF.from_pretrained(model_name, model_hf = GPT2LMHeadModelHF.from_pretrained(model_name, torch_dtype=dtype).to(
torch_dtype=dtype).to(device=device) device=device
)
model_ref.eval() model_ref.eval()
model_hf.eval() model_hf.eval()
torch.manual_seed(0) torch.manual_seed(0)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2") tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
input_ids = tokenizer("Hello, my dog is cute and he", input_ids = tokenizer("Hello, my dog is cute and he", return_tensors="pt").input_ids.to(
return_tensors="pt").input_ids.to(device=device) device=device
)
max_length = 25 max_length = 25
# input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda') # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
# max_length = input_ids.shape[1] + 40 # max_length = input_ids.shape[1] + 40
@@ -74,61 +73,102 @@ def test_greedy_decode_gpt2(model_name, rotary, optimized, fused_ft_kernel):
scores.append(model(cur_input_ids).logits[:, -1]) scores.append(model(cur_input_ids).logits[:, -1])
sequences.append(scores[-1].argmax(dim=-1)) sequences.append(scores[-1].argmax(dim=-1))
for _ in range(input_ids.shape[1] + 1, max_length): for _ in range(input_ids.shape[1] + 1, max_length):
cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], 'b -> b 1')], dim=-1) cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], "b -> b 1")], dim=-1)
scores.append(model(cur_input_ids).logits[:, -1]) scores.append(model(cur_input_ids).logits[:, -1])
sequences.append(scores[-1].argmax(dim=-1)) sequences.append(scores[-1].argmax(dim=-1))
sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1) sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
scores = tuple(scores) scores = tuple(scores)
out = model.generate(input_ids=input_ids, max_length=max_length, out = model.generate(
input_ids=input_ids,
max_length=max_length,
fused_ft_kernel=fused_ft_kernel, fused_ft_kernel=fused_ft_kernel,
return_dict_in_generate=True, output_scores=True, timing=True) return_dict_in_generate=True,
output_scores=True,
timing=True,
)
print(out.sequences) print(out.sequences)
print(tokenizer.batch_decode(out.sequences.tolist())) print(tokenizer.batch_decode(out.sequences.tolist()))
if fused_ft_kernel: if fused_ft_kernel:
out_cg = model.generate(input_ids=input_ids, max_length=max_length, out_cg = model.generate(
fused_ft_kernel=fused_ft_kernel, cg=True, input_ids=input_ids,
return_dict_in_generate=True, output_scores=True, timing=True) max_length=max_length,
fused_ft_kernel=fused_ft_kernel,
cg=True,
return_dict_in_generate=True,
output_scores=True,
timing=True,
)
print(out_cg.sequences) print(out_cg.sequences)
if not rotary: if not rotary:
out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length, out_hf = model_hf.generate(
return_dict_in_generate=True, output_scores=True) input_ids=input_ids,
out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length, max_length=max_length,
return_dict_in_generate=True, output_scores=True) return_dict_in_generate=True,
output_scores=True,
print(f'Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}') )
print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}') out_ref = model_ref.generate(
print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}') input_ids=input_ids,
print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}') max_length=max_length,
return_dict_in_generate=True,
output_scores=True,
)
print(
f"Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
)
print(
f"Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
)
print(
f"HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
)
print(
f"HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
)
print(tokenizer.batch_decode(out_ref.sequences.tolist())) print(tokenizer.batch_decode(out_ref.sequences.tolist()))
assert torch.all(out.sequences == sequences) assert torch.all(out.sequences == sequences)
assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), assert torch.allclose(
rtol=rtol, atol=atol) torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), rtol=rtol, atol=atol
)
if not rotary: if not rotary:
assert torch.all(out.sequences == out_ref.sequences) assert torch.all(out.sequences == out_ref.sequences)
assert torch.all(out.sequences == out_hf.sequences) assert torch.all(out.sequences == out_hf.sequences)
assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() assert (
torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)
).abs().max().item() < 3 * (
torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)
).abs().max().item()
@pytest.mark.parametrize('model_name', ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b", "facebook/opt-2.7b", "facebook/opt-6.7b"]) @pytest.mark.parametrize(
"model_name",
[
"facebook/opt-125m",
"facebook/opt-350m",
"facebook/opt-1.3b",
"facebook/opt-2.7b",
"facebook/opt-6.7b",
],
)
# @pytest.mark.parametrize('model_name', ["facebook/opt-125m"]) # @pytest.mark.parametrize('model_name', ["facebook/opt-125m"])
def test_greedy_decode_opt(model_name): def test_greedy_decode_opt(model_name):
"""Check that our implementation of OPT generation matches the HF implementation: """Check that our implementation of OPT generation matches the HF implementation:
the scores in fp16 should be around the same as the HF scores in fp16, when compared to the scores in fp16 should be around the same as the HF scores in fp16, when compared to
the HF scores in fp32. the HF scores in fp32.
""" """
print(f'\nMODEL: {model_name}') print(f"\nMODEL: {model_name}")
verbose = False verbose = False
dtype = torch.float16 dtype = torch.float16
device = 'cuda' device = "cuda"
rtol, atol = 3e-3, 3e-1 rtol, atol = 3e-3, 3e-1
fused_ft_kernel = True fused_ft_kernel = True
config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name)) config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
# Only prenorm supports residual_in_fp32 # Only prenorm supports residual_in_fp32
config.residual_in_fp32 = getattr(config, 'prenorm', True) config.residual_in_fp32 = getattr(config, "prenorm", True)
config.use_flash_attn = True config.use_flash_attn = True
config.fused_bias_fc = True config.fused_bias_fc = True
config.fused_mlp = True config.fused_mlp = True
@@ -143,8 +183,9 @@ def test_greedy_decode_opt(model_name):
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
eos_token_id = tokenizer.eos_token_id eos_token_id = tokenizer.eos_token_id
input_ids = tokenizer("Hello, my dog is cute and he", input_ids = tokenizer("Hello, my dog is cute and he", return_tensors="pt").input_ids.to(
return_tensors="pt").input_ids.to(device=device) device=device
)
max_length = 25 max_length = 25
# input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda') # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
# max_length = input_ids.shape[1] + 40 # max_length = input_ids.shape[1] + 40
@@ -157,7 +198,7 @@ def test_greedy_decode_opt(model_name):
scores.append(model(cur_input_ids).logits[:, -1]) scores.append(model(cur_input_ids).logits[:, -1])
sequences.append(scores[-1].argmax(dim=-1)) sequences.append(scores[-1].argmax(dim=-1))
for _ in range(input_ids.shape[1] + 1, max_length): for _ in range(input_ids.shape[1] + 1, max_length):
cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], 'b -> b 1')], dim=-1) cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], "b -> b 1")], dim=-1)
scores.append(model(cur_input_ids).logits[:, -1]) scores.append(model(cur_input_ids).logits[:, -1])
sequences.append(scores[-1].argmax(dim=-1)) sequences.append(scores[-1].argmax(dim=-1))
if eos_token_id is not None and (sequences[-1] == eos_token_id).all(): if eos_token_id is not None and (sequences[-1] == eos_token_id).all():
@@ -165,31 +206,41 @@ def test_greedy_decode_opt(model_name):
sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1) sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
scores = tuple(scores) scores = tuple(scores)
print('Without CUDA graph') print("Without CUDA graph")
torch.cuda.synchronize() torch.cuda.synchronize()
start = time.time() start = time.time()
out = model.generate(input_ids=input_ids, max_length=max_length, out = model.generate(
eos_token_id=eos_token_id, fused_ft_kernel=fused_ft_kernel, input_ids=input_ids,
return_dict_in_generate=True, output_scores=True, timing=True) max_length=max_length,
eos_token_id=eos_token_id,
fused_ft_kernel=fused_ft_kernel,
return_dict_in_generate=True,
output_scores=True,
timing=True,
)
torch.cuda.synchronize() torch.cuda.synchronize()
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms') print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
if verbose: if verbose:
print(out.sequences) print(out.sequences)
print(tokenizer.batch_decode(out.sequences.tolist())) print(tokenizer.batch_decode(out.sequences.tolist()))
if fused_ft_kernel: if fused_ft_kernel:
# Capture graph outside the timing loop # Capture graph outside the timing loop
batch_size, seqlen_og = input_ids.shape batch_size, seqlen_og = input_ids.shape
model._decoding_cache = update_graph_cache( model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
model, None, batch_size, seqlen_og, max_length print("With CUDA graph")
)
print('With CUDA graph')
torch.cuda.synchronize() torch.cuda.synchronize()
start = time.time() start = time.time()
out_cg = model.generate(input_ids=input_ids, max_length=max_length, out_cg = model.generate(
fused_ft_kernel=fused_ft_kernel, cg=True, input_ids=input_ids,
return_dict_in_generate=True, output_scores=True, timing=True) max_length=max_length,
fused_ft_kernel=fused_ft_kernel,
cg=True,
return_dict_in_generate=True,
output_scores=True,
timing=True,
)
torch.cuda.synchronize() torch.cuda.synchronize()
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms') print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
if verbose: if verbose:
print(out_cg.sequences) print(out_cg.sequences)
print(tokenizer.batch_decode(out_cg.sequences.tolist())) print(tokenizer.batch_decode(out_cg.sequences.tolist()))
@@ -201,10 +252,11 @@ def test_greedy_decode_opt(model_name):
print("HF fp16") print("HF fp16")
torch.cuda.synchronize() torch.cuda.synchronize()
start = time.time() start = time.time()
out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length, out_hf = model_hf.generate(
return_dict_in_generate=True, output_scores=True) input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
)
torch.cuda.synchronize() torch.cuda.synchronize()
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms') print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
del model_hf del model_hf
model_ref = OPTForCausalLM.from_pretrained(model_name).to(device=device) model_ref = OPTForCausalLM.from_pretrained(model_name).to(device=device)
@@ -212,23 +264,35 @@ def test_greedy_decode_opt(model_name):
print("HF fp32") print("HF fp32")
torch.cuda.synchronize() torch.cuda.synchronize()
start = time.time() start = time.time()
out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length, out_ref = model_ref.generate(
return_dict_in_generate=True, output_scores=True) input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
)
torch.cuda.synchronize() torch.cuda.synchronize()
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms') print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
del model_ref del model_ref
print(tokenizer.batch_decode(out_ref.sequences.tolist())) print(tokenizer.batch_decode(out_ref.sequences.tolist()))
if verbose: if verbose:
print(f'Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}') print(
print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}') f"Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}') )
print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}') print(
f"Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
)
print(
f"HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
)
print(
f"HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
)
assert torch.all(out.sequences == sequences) assert torch.all(out.sequences == sequences)
assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), assert torch.allclose(
rtol=rtol, atol=atol) torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), rtol=rtol, atol=atol
)
assert torch.all(out.sequences == out_ref.sequences) assert torch.all(out.sequences == out_ref.sequences)
assert torch.all(out.sequences == out_hf.sequences) assert torch.all(out.sequences == out_hf.sequences)
assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (
torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)
).abs().max().item()
@@ -2,34 +2,37 @@
import re
import time

import pytest
import torch
from einops import rearrange
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.utils.generation import update_graph_cache
from transformers import GPT2Config
def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs): def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
out = model.generate(input_ids=input_ids, max_length=max_length, fused_ft_kernel=True, out = model.generate(
teacher_outputs=teacher_outputs, return_dict_in_generate=True, input_ids=input_ids,
output_scores=True, timing=True, **kwargs) max_length=max_length,
fused_ft_kernel=True,
teacher_outputs=teacher_outputs,
return_dict_in_generate=True,
output_scores=True,
timing=True,
**kwargs,
)
return torch.stack(out.scores, dim=1) return torch.stack(out.scores, dim=1)
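# Editorial note: teacher_outputs pins the token emitted at every step, so the plain and
# CUDA-graph paths are forced through the identical sequence and their logits can be
# compared with torch.equal below. A hypothetical, simplified sketch of that forcing
# logic (the real version lives in flash_attn.utils.generation):
def _next_token(logits, teacher_outputs, step):
    # Greedy choice unless a teacher token is supplied for this position
    if teacher_outputs is not None:
        return teacher_outputs[:, step]
    return logits.argmax(dim=-1)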
@pytest.mark.parametrize('seqlen,maxlen', [(10, 20), (30, 150), (3000, 3400), (14000, 15000)]) @pytest.mark.parametrize("seqlen,maxlen", [(10, 20), (30, 150), (3000, 3400), (14000, 15000)])
# @pytest.mark.parametrize('seqlen,maxlen', [(10, 20)]) # @pytest.mark.parametrize('seqlen,maxlen', [(10, 20)])
@pytest.mark.parametrize('rotary', [None, "interleaved", "block"]) @pytest.mark.parametrize("rotary", [None, "interleaved", "block"])
# @pytest.mark.parametrize('rotary', [None]) # @pytest.mark.parametrize('rotary', [None])
@pytest.mark.parametrize('model_name', ["gpt2"]) @pytest.mark.parametrize("model_name", ["gpt2"])
def test_greedy_decode_gpt2_cg(model_name, rotary, seqlen, maxlen): def test_greedy_decode_gpt2_cg(model_name, rotary, seqlen, maxlen):
"""Check that decoding with CUDA graph is the same as decoding without CUDA graph. """Check that decoding with CUDA graph is the same as decoding without CUDA graph."""
"""
dtype = torch.float16 dtype = torch.float16
device = 'cuda' device = "cuda"
rtol, atol = 3e-3, 3e-1 rtol, atol = 3e-3, 3e-1
config = GPT2Config.from_pretrained(model_name) config = GPT2Config.from_pretrained(model_name)
config.n_positions = 16 * 1024 config.n_positions = 16 * 1024
@@ -49,10 +52,12 @@ def test_greedy_decode_gpt2_cg(model_name, rotary, seqlen, maxlen):
torch.manual_seed(0) torch.manual_seed(0)
batch_size = 1 batch_size = 1
input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, input_ids = torch.randint(
device=device) 0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
teacher_outputs = torch.randint(0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, )
device=device) teacher_outputs = torch.randint(
0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device
)
logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs) logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True) logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
@@ -61,20 +66,24 @@ def test_greedy_decode_gpt2_cg(model_name, rotary, seqlen, maxlen):
# Try increasing batch size and seqlen, then decrease them to see if it's still correct # Try increasing batch size and seqlen, then decrease them to see if it's still correct
batch_size = 3 batch_size = 3
maxlen += 30 maxlen += 30
input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, input_ids = torch.randint(
device=device) 0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
teacher_outputs = torch.randint(0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, )
device=device) teacher_outputs = torch.randint(
0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device
)
logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs) logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True) logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
assert torch.equal(logits, logits_cg) assert torch.equal(logits, logits_cg)
batch_size = 2 batch_size = 2
maxlen -= 35 maxlen -= 35
input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, input_ids = torch.randint(
device=device) 0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
teacher_outputs = torch.randint(0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, )
device=device) teacher_outputs = torch.randint(
0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device
)
logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs) logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True) logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
assert torch.equal(logits, logits_cg) assert torch.equal(logits, logits_cg)
@@ -3,27 +3,23 @@
import os
import re

import pytest
import torch
from einops import rearrange
from flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_gpt2
from flash_attn.utils.distributed import all_gather_raw
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import GPT2Config, GPT2Tokenizer
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF
# @pytest.mark.parametrize('world_size', [1, 2, 4, 8]) # @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
@pytest.mark.parametrize('world_size', [2]) @pytest.mark.parametrize("world_size", [2])
# @pytest.mark.parametrize('fused_ft_kernel', [False, True]) # @pytest.mark.parametrize('fused_ft_kernel', [False, True])
@pytest.mark.parametrize('fused_ft_kernel', [True]) @pytest.mark.parametrize("fused_ft_kernel", [True])
# @pytest.mark.parametrize('rotary', [False, True]) # @pytest.mark.parametrize('rotary', [False, True])
@pytest.mark.parametrize('rotary', [False]) @pytest.mark.parametrize("rotary", [False])
@pytest.mark.parametrize('model_name', ["gpt2"]) @pytest.mark.parametrize("model_name", ["gpt2"])
def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size): def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
"""Check that our implementation of GPT2 generation matches the HF implementation: """Check that our implementation of GPT2 generation matches the HF implementation:
the scores in fp16 should be around the same as the HF scores in fp16, when compared to the scores in fp16 should be around the same as the HF scores in fp16, when compared to
@@ -45,23 +41,31 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0" os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
if not torch.distributed.is_initialized(): if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend='nccl', init_method='env://') torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f'cuda:{torch.distributed.get_rank()}' device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size() assert world_size <= torch.distributed.get_world_size()
# Need this, otherwise when we capture the graph the process for GPU 1 would run on both # Need this, otherwise when we capture the graph the process for GPU 1 would run on both
# GPU0 and GPU1 and things would hang # GPU0 and GPU1 and things would hang
torch.cuda.set_device(device) torch.cuda.set_device(device)
from apex.transformer import parallel_state from apex.transformer import parallel_state
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size) parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank() rank = parallel_state.get_tensor_model_parallel_rank()
process_group = parallel_state.get_tensor_model_parallel_group() process_group = parallel_state.get_tensor_model_parallel_group()
# if not rotary, we load the weight from HF but ignore the position embeddings. # if not rotary, we load the weight from HF but ignore the position embeddings.
# The model would be nonsense but it doesn't matter for the test. # The model would be nonsense but it doesn't matter for the test.
model = GPTLMHeadModel.from_pretrained(model_name, config, strict=not rotary, device=device, model = GPTLMHeadModel.from_pretrained(
dtype=dtype, process_group=process_group, model_name,
world_size=world_size, rank=rank) config,
strict=not rotary,
device=device,
dtype=dtype,
process_group=process_group,
world_size=world_size,
rank=rank,
)
model.eval() model.eval()
if not rotary: if not rotary:
@@ -72,8 +76,9 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
torch.manual_seed(0) torch.manual_seed(0)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2") tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
input_ids = tokenizer("Hello, my dog is cute and ", input_ids = tokenizer("Hello, my dog is cute and ", return_tensors="pt").input_ids.to(
return_tensors="pt").input_ids.to(device=device) device=device
)
max_length = 30 max_length = 30
# input_ids = torch.randint(0, 100, (1, 10), dtype=torch.long, device='cuda') # input_ids = torch.randint(0, 100, (1, 10), dtype=torch.long, device='cuda')
# max_length = input_ids.shape[1] + 40 # max_length = input_ids.shape[1] + 40
@@ -84,50 +89,87 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
cur_input_ids = input_ids cur_input_ids = input_ids
with torch.inference_mode(): with torch.inference_mode():
logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group) logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group)
logits = rearrange(logits, '(n b) d -> b (n d)', logits = rearrange(logits, "(n b) d -> b (n d)", b=input_ids.shape[0])[
b=input_ids.shape[0])[..., :config.vocab_size] ..., : config.vocab_size
]
scores.append(logits) scores.append(logits)
sequences.append(scores[-1].argmax(dim=-1)) sequences.append(scores[-1].argmax(dim=-1))
for _ in range(input_ids.shape[1] + 1, max_length): for _ in range(input_ids.shape[1] + 1, max_length):
cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], 'b -> b 1')], dim=-1) cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], "b -> b 1")], dim=-1)
logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group) logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group)
logits = rearrange(logits, '(n b) d -> b (n d)', logits = rearrange(logits, "(n b) d -> b (n d)", b=input_ids.shape[0])[
b=input_ids.shape[0])[..., :config.vocab_size] ..., : config.vocab_size
]
scores.append(logits) scores.append(logits)
sequences.append(scores[-1].argmax(dim=-1)) sequences.append(scores[-1].argmax(dim=-1))
sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1) sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
scores = tuple(scores) scores = tuple(scores)
print(sequences) print(sequences)
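# Editorial aside, not part of the commit: with tensor parallelism the lm_head is sharded
# over the vocabulary, so each rank's logits only cover vocab_size / world_size entries.
# The loop above gathers the shards with all_gather_raw and reassembles them with
# rearrange; a hypothetical equivalent written against plain torch.distributed:
def _gather_vocab_parallel_logits(local_logits, process_group, vocab_size):
    world_size = torch.distributed.get_world_size(process_group)
    shards = [torch.empty_like(local_logits) for _ in range(world_size)]
    torch.distributed.all_gather(shards, local_logits.contiguous(), group=process_group)
    stacked = torch.cat(shards, dim=0)  # (world_size * batch, vocab_shard)
    full = rearrange(stacked, "(n b) d -> b (n d)", b=local_logits.shape[0])
    return full[..., :vocab_size]  # drop the padding used to make vocab divisible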
out = model.generate(input_ids=input_ids, max_length=max_length, tensor_parallel=world_size, out = model.generate(
vocab_size=config.vocab_size, fused_ft_kernel=fused_ft_kernel, input_ids=input_ids,
return_dict_in_generate=True, output_scores=True, timing=True) max_length=max_length,
tensor_parallel=world_size,
vocab_size=config.vocab_size,
fused_ft_kernel=fused_ft_kernel,
return_dict_in_generate=True,
output_scores=True,
timing=True,
)
print(out.sequences) print(out.sequences)
if fused_ft_kernel: if fused_ft_kernel:
out_cg = model.generate( out_cg = model.generate(
input_ids=input_ids, max_length=max_length, tensor_parallel=world_size, input_ids=input_ids,
vocab_size=config.vocab_size, fused_ft_kernel=fused_ft_kernel, cg=True, max_length=max_length,
return_dict_in_generate=True, output_scores=True, timing=True) tensor_parallel=world_size,
vocab_size=config.vocab_size,
fused_ft_kernel=fused_ft_kernel,
cg=True,
return_dict_in_generate=True,
output_scores=True,
timing=True,
)
print(out_cg.sequences) print(out_cg.sequences)
if not rotary: if not rotary:
out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length, out_hf = model_hf.generate(
return_dict_in_generate=True, output_scores=True) input_ids=input_ids,
out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length, max_length=max_length,
return_dict_in_generate=True, output_scores=True) return_dict_in_generate=True,
output_scores=True,
print(f'Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}') )
print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}') out_ref = model_ref.generate(
print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}') input_ids=input_ids,
print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}') max_length=max_length,
return_dict_in_generate=True,
output_scores=True,
)
print(
f"Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
)
print(
f"Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
)
print(
f"HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
)
print(
f"HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
)
assert torch.all(out.sequences == sequences) assert torch.all(out.sequences == sequences)
assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), assert torch.allclose(
rtol=rtol, atol=atol) torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), rtol=rtol, atol=atol
)
if not rotary: if not rotary:
assert torch.all(out.sequences == out_ref.sequences) assert torch.all(out.sequences == out_ref.sequences)
assert torch.all(out.sequences == out_hf.sequences) assert torch.all(out.sequences == out_hf.sequences)
assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() assert (
torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)
).abs().max().item() < 3 * (
torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)
).abs().max().item()
parallel_state.destroy_model_parallel() parallel_state.destroy_model_parallel()
@@ -2,37 +2,37 @@
import time

import pytest
import torch
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gpt_neox import gpt_neox_config_to_gpt2_config, remap_state_dict_hf_gpt_neox
from flash_attn.utils.generation import update_graph_cache
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import AutoTokenizer, GPTNeoXConfig
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
@pytest.mark.parametrize('model_name', ["EleutherAI/gpt-neox-20b"]) @pytest.mark.parametrize("model_name", ["EleutherAI/gpt-neox-20b"])
def test_gptj_state_dict(model_name): def test_gptj_state_dict(model_name):
config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(model_name)) config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(model_name))
pretrained_state_dict = remap_state_dict_hf_gpt_neox(state_dict_from_pretrained(model_name), config) pretrained_state_dict = remap_state_dict_hf_gpt_neox(
model = GPTLMHeadModel(config, device='meta') # Without device='meta' init is very slow state_dict_from_pretrained(model_name), config
)
model = GPTLMHeadModel(config, device="meta") # Without device='meta' init is very slow
state_dict = model.state_dict() state_dict = model.state_dict()
assert state_dict.keys() == pretrained_state_dict.keys() assert state_dict.keys() == pretrained_state_dict.keys()
for k in state_dict.keys(): for k in state_dict.keys():
assert state_dict[k].shape == pretrained_state_dict[k].shape assert state_dict[k].shape == pretrained_state_dict[k].shape
@pytest.mark.parametrize('model_name', ["EleutherAI/gpt-neox-20b"]) @pytest.mark.parametrize("model_name", ["EleutherAI/gpt-neox-20b"])
def test_gpt_neox_optimized(model_name): def test_gpt_neox_optimized(model_name):
"""Check that our implementation of GPT-NeoX (with all optimizations enabled) matches the """Check that our implementation of GPT-NeoX (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32. forward pass in fp16, when compared to the HF forward pass in fp32.
""" """
dtype = torch.float16 dtype = torch.float16
device = 'cuda' device = "cuda"
config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(model_name)) config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(model_name))
config.use_flash_attn = True config.use_flash_attn = True
config.fused_bias_fc = True config.fused_bias_fc = True
@@ -47,8 +47,9 @@ def test_gpt_neox_optimized(model_name):
batch_size = 2 batch_size = 2
max_seqlen = 256 max_seqlen = 256
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device) seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, input_ids = torch.randint(
device=device) 0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
)
with torch.no_grad(): with torch.no_grad():
out = model.transformer(input_ids) out = model.transformer(input_ids)
logits = model(input_ids).logits logits = model(input_ids).logits
@@ -56,31 +57,36 @@ def test_gpt_neox_optimized(model_name):
# Need at least 2 GPUs, otherwise we'll OOM # Need at least 2 GPUs, otherwise we'll OOM
# Without device_map, the model is loaded on the CPU, which is very slow # Without device_map, the model is loaded on the CPU, which is very slow
model_ref = GPTNeoXForCausalLM.from_pretrained(model_name, device_map='auto') model_ref = GPTNeoXForCausalLM.from_pretrained(model_name, device_map="auto")
model_ref.eval() model_ref.eval()
with torch.no_grad(): with torch.no_grad():
out_ref = model_ref.gpt_neox(input_ids).last_hidden_state.to(device=device) out_ref = model_ref.gpt_neox(input_ids).last_hidden_state.to(device=device)
logits_ref = model_ref(input_ids).logits.to(device=device) logits_ref = model_ref(input_ids).logits.to(device=device)
del model_ref del model_ref
model_hf = GPTNeoXForCausalLM.from_pretrained(model_name, torch_dtype=dtype, model_hf = GPTNeoXForCausalLM.from_pretrained(
device_map={"": device}) model_name, torch_dtype=dtype, device_map={"": device}
)
model_hf.eval() model_hf.eval()
with torch.no_grad(): with torch.no_grad():
out_hf = model_hf.gpt_neox(input_ids).last_hidden_state out_hf = model_hf.gpt_neox(input_ids).last_hidden_state
logits_hf = model_hf(input_ids).logits logits_hf = model_hf(input_ids).logits
del model_hf del model_hf
print(f'Output max diff: {(out - out_ref).abs().max().item()}') print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}') print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}') print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item() assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()
assert (out - out_ref).abs().mean().item() < 2 * (out_hf - out_ref).abs().mean().item() assert (out - out_ref).abs().mean().item() < 2 * (out_hf - out_ref).abs().mean().item()
print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}') print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}') print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}') print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}') print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
assert (logits - logits_ref).abs().max().item() < 2 * (logits_hf - logits_ref).abs().max().item() assert (logits - logits_ref).abs().max().item() < 2 * (
assert (logits - logits_ref).abs().mean().item() < 2 * (logits_hf - logits_ref).abs().mean().item() logits_hf - logits_ref
).abs().max().item()
assert (logits - logits_ref).abs().mean().item() < 2 * (
logits_hf - logits_ref
).abs().mean().item()
@@ -3,33 +3,29 @@
import math

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from apex.transformer import parallel_state
from einops import rearrange
from flash_attn.losses.cross_entropy import CrossEntropyLoss
from flash_attn.models.gpt import GPTLMHeadModel, shard_state_dict_tp
from flash_attn.utils.distributed import allreduce_sequence_parallel_grad
from transformers import GPT2Config
is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8 is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
@pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else [])) @pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16]) # @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('world_size', [1, 2, 4, 8]) @pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2]) # @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('sequence_parallel', [True, False]) @pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False]) # @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize('has_pos_emb', [True, False]) @pytest.mark.parametrize("has_pos_emb", [True, False])
# @pytest.mark.parametrize('has_pos_emb', [True]) # @pytest.mark.parametrize('has_pos_emb', [True])
@pytest.mark.parametrize('dim', [1024]) @pytest.mark.parametrize("dim", [1024])
def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype): def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
head_dim = 64 head_dim = 64
assert dim % head_dim == 0 assert dim % head_dim == 0
@@ -40,8 +36,8 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
num_layers = 2 num_layers = 2
rtol, atol = (3e-3, 1e-1) if dtype == torch.bfloat16 else (3e-3, 1e-2) rtol, atol = (3e-3, 1e-1) if dtype == torch.bfloat16 else (3e-3, 1e-2)
if not torch.distributed.is_initialized(): if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend='nccl', init_method='env://') torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f'cuda:{torch.distributed.get_rank()}' device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size() assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size) parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank() rank = parallel_state.get_tensor_model_parallel_rank()
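
Editor's note: the test is launched with torchrun, so every pytest process initializes NCCL lazily, derives its device from its distributed rank, and then asks apex for its tensor-model-parallel rank. A minimal single-process stand-in for that launch pattern (assumed setup; gloo/CPU here so it runs anywhere, whereas the tests use NCCL with one GPU per rank and apex's parallel_state on top):

import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
if not dist.is_initialized():
    dist.init_process_group(backend="gloo", init_method="env://")
rank = dist.get_rank()
# Each rank pins itself to its own device; under NCCL this is f"cuda:{rank}".
device = f"cuda:{rank}" if torch.cuda.is_available() else "cpu"
print(f"rank={rank} world_size={dist.get_world_size()} device={device}")
dist.destroy_process_group()
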
...@@ -57,15 +53,25 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype): ...@@ -57,15 +53,25 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
# as rank 0 will have an extra bias that changes the RNG. # as rank 0 will have an extra bias that changes the RNG.
g = torch.randn(batch_size * seqlen, device=device) g = torch.randn(batch_size * seqlen, device=device)
config = GPT2Config(n_embd=dim, n_head=num_heads, n_layer=num_layers, config = GPT2Config(
n_embd=dim,
n_head=num_heads,
n_layer=num_layers,
n_positions=seqlen if has_pos_emb else 0, n_positions=seqlen if has_pos_emb else 0,
vocab_size=50257, resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0, vocab_size=50257,
scale_attn_by_inverse_layer_idx=True, use_flash_attn=True, resid_pdrop=0.0,
fused_mlp=True, fused_bias_fc=True, fused_dropout_add_ln=True, embd_pdrop=0.0,
attn_pdrop=0.0,
scale_attn_by_inverse_layer_idx=True,
use_flash_attn=True,
fused_mlp=True,
fused_bias_fc=True,
fused_dropout_add_ln=True,
residual_in_fp32=True, residual_in_fp32=True,
rotary_emb_fraction=0.0 if has_pos_emb else 0.5, rotary_emb_fraction=0.0 if has_pos_emb else 0.5,
pad_vocab_size_multiple=8 * world_size, pad_vocab_size_multiple=8 * world_size,
sequence_parallel=sequence_parallel) sequence_parallel=sequence_parallel,
)
config.vocab_size = math.ceil(config.vocab_size / (8 * world_size)) * (8 * world_size) config.vocab_size = math.ceil(config.vocab_size / (8 * world_size)) * (8 * world_size)
model_pt = GPTLMHeadModel(config, device=device) model_pt = GPTLMHeadModel(config, device=device)
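
Editor's note: the vocab_size adjustment above pads the GPT-2 vocabulary up to a multiple of 8 * world_size so that every tensor-parallel rank receives an equal, 8-aligned slice of the embedding and LM-head matrices. A small worked example of that rounding:

import math

def round_up(n: int, multiple: int) -> int:
    """Round n up to the nearest multiple (the padding applied to vocab_size above)."""
    return math.ceil(n / multiple) * multiple

world_size = 8
padded = round_up(50257, 8 * world_size)   # GPT-2 vocab, 8 ranks -> multiple of 64
assert padded == 50304 and padded % world_size == 0
print(50257, "->", padded, "| per-rank partition:", padded // world_size)  # 6288
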
...@@ -73,6 +79,7 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype): ...@@ -73,6 +79,7 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
if isinstance(module, nn.LayerNorm): if isinstance(module, nn.LayerNorm):
nn.init.normal_(module.weight) nn.init.normal_(module.weight)
nn.init.normal_(module.bias) nn.init.normal_(module.bias)
model_pt.apply(init_layer_norm) model_pt.apply(init_layer_norm)
model = GPTLMHeadModel(config, process_group=process_group, device=device) model = GPTLMHeadModel(config, process_group=process_group, device=device)
...@@ -82,15 +89,17 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype): ...@@ -82,15 +89,17 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
torch.distributed.all_gather_into_tensor( torch.distributed.all_gather_into_tensor(
sharded_nparams_all, torch.tensor([sharded_nparams], device=device), group=process_group sharded_nparams_all, torch.tensor([sharded_nparams], device=device), group=process_group
) )
shared_nparams = sum(p.numel() for p in model.parameters() shared_nparams = sum(
if getattr(p, '_shared_params', False)) p.numel() for p in model.parameters() if getattr(p, "_shared_params", False)
)
shared_nparams_all = torch.empty(world_size, dtype=torch.long, device=device) shared_nparams_all = torch.empty(world_size, dtype=torch.long, device=device)
torch.distributed.all_gather_into_tensor( torch.distributed.all_gather_into_tensor(
shared_nparams_all, torch.tensor([shared_nparams], device=device), group=process_group shared_nparams_all, torch.tensor([shared_nparams], device=device), group=process_group
) )
assert torch.all(shared_nparams_all == shared_nparams) assert torch.all(shared_nparams_all == shared_nparams)
assert total_nparams == ((sharded_nparams_all - shared_nparams_all).sum().item() assert total_nparams == (
+ shared_nparams) (sharded_nparams_all - shared_nparams_all).sum().item() + shared_nparams
)
# vocab_size has been rounded up here # vocab_size has been rounded up here
partition_vocab_size = config.vocab_size // world_size partition_vocab_size = config.vocab_size // world_size
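
Editor's note: the parameter-count check above relies on the bookkeeping that sharded parameters are split across ranks while shared (replicated) parameters appear once per rank, so the unsharded total is sum_r(sharded_r - shared) + shared. A toy, one-process version of that arithmetic with made-up counts:

import torch

world_size = 4
total_nparams = 1000
shared = 40                                        # e.g. LayerNorm params, replicated on every rank
sharded_per_rank = (total_nparams - shared) // world_size + shared
sharded_all = torch.full((world_size,), sharded_per_rank, dtype=torch.long)
shared_all = torch.full((world_size,), shared, dtype=torch.long)
assert torch.all(shared_all == shared)
assert total_nparams == (sharded_all - shared_all).sum().item() + shared
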
...@@ -100,18 +109,20 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype): ...@@ -100,18 +109,20 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
model.load_state_dict(shard_state_dict_tp(model_pt.state_dict(), config, world_size, rank)) model.load_state_dict(shard_state_dict_tp(model_pt.state_dict(), config, world_size, rank))
model.tie_weights() model.tie_weights()
with torch.autocast(device_type='cuda', dtype=dtype): with torch.autocast(device_type="cuda", dtype=dtype):
out = model(input_ids[:, :-1]).logits out = model(input_ids[:, :-1]).logits
if not sequence_parallel: if not sequence_parallel:
out = rearrange(out, 'b s d -> (b s) d') out = rearrange(out, "b s d -> (b s) d")
out_pt = rearrange(model_pt(input_ids[:, :-1]).logits, 'b s d -> (b s) d') out_pt = rearrange(model_pt(input_ids[:, :-1]).logits, "b s d -> (b s) d")
partition_batch_dim = batch_size * seqlen // world_size partition_batch_dim = batch_size * seqlen // world_size
assert torch.allclose( assert torch.allclose(
out, out_pt[:, rank * partition_vocab_size:(rank + 1) * partition_vocab_size], out,
rtol=rtol, atol=atol out_pt[:, rank * partition_vocab_size : (rank + 1) * partition_vocab_size],
rtol=rtol,
atol=atol,
) )
loss_fn = CrossEntropyLoss(inplace_backward=True, reduction='none', process_group=process_group) loss_fn = CrossEntropyLoss(inplace_backward=True, reduction="none", process_group=process_group)
loss_fn_pt = CrossEntropyLoss(inplace_backward=True, reduction='none') loss_fn_pt = CrossEntropyLoss(inplace_backward=True, reduction="none")
loss = loss_fn(out, input_ids[:, 1:].flatten()) loss = loss_fn(out, input_ids[:, 1:].flatten())
loss_pt = loss_fn_pt(out_pt, input_ids[:, 1:].flatten()) loss_pt = loss_fn_pt(out_pt, input_ids[:, 1:].flatten())
assert torch.allclose(loss, loss_pt, rtol=rtol, atol=atol) assert torch.allclose(loss, loss_pt, rtol=rtol, atol=atol)
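
Editor's note: the comparison above slices the reference logits by partition_vocab_size because the LM head is column-parallel: each rank only produces the logits for its own contiguous slice of the (padded) vocabulary, and the parallel CrossEntropyLoss reduces over the process group. A one-device simulation of that slicing (illustrative tensors, not the flash_attn modules):

import torch

torch.manual_seed(0)
tokens, dim, vocab, world_size = 6, 16, 32, 4
hidden = torch.randn(tokens, dim)
lm_head = torch.randn(vocab, dim)              # full (unsharded) LM head
full_logits = hidden @ lm_head.t()

partition_vocab = vocab // world_size
for rank in range(world_size):
    head_shard = lm_head[rank * partition_vocab : (rank + 1) * partition_vocab]
    local_logits = hidden @ head_shard.t()     # what a column-parallel rank computes
    assert torch.allclose(
        local_logits, full_logits[:, rank * partition_vocab : (rank + 1) * partition_vocab]
    )
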
...@@ -121,73 +132,105 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype): ...@@ -121,73 +132,105 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
allreduce_sequence_parallel_grad(model, process_group) allreduce_sequence_parallel_grad(model, process_group)
parallel_state.destroy_model_parallel() parallel_state.destroy_model_parallel()
grad_dict = shard_state_dict_tp({k: v.grad for k, v in model_pt.named_parameters()}, grad_dict = shard_state_dict_tp(
config, world_size, rank) {k: v.grad for k, v in model_pt.named_parameters()}, config, world_size, rank
)
assert torch.allclose( assert torch.allclose(
model.transformer.embeddings.word_embeddings.weight.grad, model.transformer.embeddings.word_embeddings.weight.grad,
grad_dict['transformer.embeddings.word_embeddings.weight'], grad_dict["transformer.embeddings.word_embeddings.weight"],
rtol=rtol, atol=atol * 5 rtol=rtol,
atol=atol * 5,
) )
if has_pos_emb: if has_pos_emb:
assert torch.allclose( assert torch.allclose(
model.transformer.embeddings.position_embeddings.weight.grad, model.transformer.embeddings.position_embeddings.weight.grad,
grad_dict['transformer.embeddings.position_embeddings.weight'], grad_dict["transformer.embeddings.position_embeddings.weight"],
rtol=rtol, atol=atol rtol=rtol,
atol=atol,
)
assert torch.allclose(
model.transformer.ln_f.weight.grad,
grad_dict["transformer.ln_f.weight"],
rtol=rtol,
atol=atol,
)
assert torch.allclose(
model.transformer.ln_f.bias.grad, grad_dict["transformer.ln_f.bias"], rtol=rtol, atol=atol
) )
assert torch.allclose(model.transformer.ln_f.weight.grad, grad_dict['transformer.ln_f.weight'],
rtol=rtol, atol=atol)
assert torch.allclose(model.transformer.ln_f.bias.grad, grad_dict['transformer.ln_f.bias'],
rtol=rtol, atol=atol)
for i in range(num_layers): for i in range(num_layers):
assert torch.allclose( assert torch.allclose(
model.transformer.layers[i].mixer.Wqkv.weight.grad, model.transformer.layers[i].mixer.Wqkv.weight.grad,
grad_dict[f'transformer.layers.{i}.mixer.Wqkv.weight'], grad_dict[f"transformer.layers.{i}.mixer.Wqkv.weight"],
rtol=rtol, atol=atol * 10 rtol=rtol,
atol=atol * 10,
) )
assert torch.allclose( assert torch.allclose(
model.transformer.layers[i].mixer.Wqkv.bias.grad, model.transformer.layers[i].mixer.Wqkv.bias.grad,
grad_dict[f'transformer.layers.{i}.mixer.Wqkv.bias'], grad_dict[f"transformer.layers.{i}.mixer.Wqkv.bias"],
rtol=rtol, atol=atol * 10 rtol=rtol,
atol=atol * 10,
) )
assert torch.allclose( assert torch.allclose(
model.transformer.layers[i].mixer.out_proj.weight.grad, model.transformer.layers[i].mixer.out_proj.weight.grad,
grad_dict[f'transformer.layers.{i}.mixer.out_proj.weight'], grad_dict[f"transformer.layers.{i}.mixer.out_proj.weight"],
rtol=rtol, atol=atol * 10 rtol=rtol,
atol=atol * 10,
) )
if rank == 0: if rank == 0:
assert torch.allclose(model.transformer.layers[i].mixer.out_proj.bias.grad, assert torch.allclose(
grad_dict[f'transformer.layers.{i}.mixer.out_proj.bias'], model.transformer.layers[i].mixer.out_proj.bias.grad,
rtol=rtol, atol=atol * 5) grad_dict[f"transformer.layers.{i}.mixer.out_proj.bias"],
rtol=rtol,
atol=atol * 5,
)
assert torch.allclose( assert torch.allclose(
model.transformer.layers[i].mlp.fc1.weight.grad, model.transformer.layers[i].mlp.fc1.weight.grad,
grad_dict[f'transformer.layers.{i}.mlp.fc1.weight'], grad_dict[f"transformer.layers.{i}.mlp.fc1.weight"],
rtol=rtol, atol=atol * 10 rtol=rtol,
atol=atol * 10,
) )
assert torch.allclose( assert torch.allclose(
model.transformer.layers[i].mlp.fc1.bias.grad, model.transformer.layers[i].mlp.fc1.bias.grad,
grad_dict[f'transformer.layers.{i}.mlp.fc1.bias'], grad_dict[f"transformer.layers.{i}.mlp.fc1.bias"],
rtol=rtol, atol=atol * 10 rtol=rtol,
atol=atol * 10,
) )
assert torch.allclose( assert torch.allclose(
model.transformer.layers[i].mlp.fc2.weight.grad, model.transformer.layers[i].mlp.fc2.weight.grad,
grad_dict[f'transformer.layers.{i}.mlp.fc2.weight'], grad_dict[f"transformer.layers.{i}.mlp.fc2.weight"],
rtol=rtol, atol=atol * 10 rtol=rtol,
atol=atol * 10,
) )
if rank == 0: if rank == 0:
assert torch.allclose(model.transformer.layers[i].mlp.fc2.bias.grad, assert torch.allclose(
grad_dict[f'transformer.layers.{i}.mlp.fc2.bias'], model.transformer.layers[i].mlp.fc2.bias.grad,
rtol=rtol, atol=atol * 5) grad_dict[f"transformer.layers.{i}.mlp.fc2.bias"],
rtol=rtol,
assert torch.allclose(model.transformer.layers[i].norm1.weight.grad, atol=atol * 5,
grad_dict[f'transformer.layers.{i}.norm1.weight'], )
rtol=rtol, atol=atol)
assert torch.allclose(model.transformer.layers[i].norm1.bias.grad, assert torch.allclose(
grad_dict[f'transformer.layers.{i}.norm1.bias'], model.transformer.layers[i].norm1.weight.grad,
rtol=rtol, atol=atol) grad_dict[f"transformer.layers.{i}.norm1.weight"],
assert torch.allclose(model.transformer.layers[i].norm2.weight.grad, rtol=rtol,
grad_dict[f'transformer.layers.{i}.norm2.weight'], atol=atol,
rtol=rtol, atol=atol) )
assert torch.allclose(model.transformer.layers[i].norm2.bias.grad, assert torch.allclose(
grad_dict[f'transformer.layers.{i}.norm2.bias'], model.transformer.layers[i].norm1.bias.grad,
rtol=rtol, atol=atol) grad_dict[f"transformer.layers.{i}.norm1.bias"],
rtol=rtol,
atol=atol,
)
assert torch.allclose(
model.transformer.layers[i].norm2.weight.grad,
grad_dict[f"transformer.layers.{i}.norm2.weight"],
rtol=rtol,
atol=atol,
)
assert torch.allclose(
model.transformer.layers[i].norm2.bias.grad,
grad_dict[f"transformer.layers.{i}.norm2.bias"],
rtol=rtol,
atol=atol,
)
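
Editor's note: the gradient checks above reuse shard_state_dict_tp on the reference model's grads so each rank compares against the matching shard: column-parallel layers (Wqkv, fc1) are sliced along the output dimension, row-parallel layers (out_proj, fc2) along the input dimension, and the row-parallel biases exist only on rank 0. A one-device illustration of the two slicing directions (stand-in tensors, not the actual grads):

import torch

torch.manual_seed(0)
world_size, rank = 4, 1
out_features, in_features = 64, 32
w_col = torch.randn(out_features, in_features)   # stand-in for a column-parallel (fc1) weight grad
w_row = torch.randn(out_features, in_features)   # stand-in for a row-parallel (out_proj) weight grad

part_out = out_features // world_size
part_in = in_features // world_size
col_shard = w_col[rank * part_out : (rank + 1) * part_out]    # slice rows (output dim)
row_shard = w_row[:, rank * part_in : (rank + 1) * part_in]   # slice columns (input dim)
print(col_shard.shape, row_shard.shape)  # torch.Size([16, 32]) torch.Size([64, 8])
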
...@@ -2,37 +2,35 @@ ...@@ -2,37 +2,35 @@
import time import time
import torch
import pytest import pytest
import torch
from transformers import GPTJConfig, AutoTokenizer
from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
from flash_attn.models.gpt import GPTLMHeadModel from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gptj import remap_state_dict_hf_gptj, gptj_config_to_gpt2_config from flash_attn.models.gptj import gptj_config_to_gpt2_config, remap_state_dict_hf_gptj
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.generation import update_graph_cache from flash_attn.utils.generation import update_graph_cache
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import AutoTokenizer, GPTJConfig
from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
@pytest.mark.parametrize('model_name', ["EleutherAI/gpt-j-6B"]) @pytest.mark.parametrize("model_name", ["EleutherAI/gpt-j-6B"])
def test_gptj_state_dict(model_name): def test_gptj_state_dict(model_name):
config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name)) config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
pretrained_state_dict = remap_state_dict_hf_gptj(state_dict_from_pretrained(model_name), config) pretrained_state_dict = remap_state_dict_hf_gptj(state_dict_from_pretrained(model_name), config)
model = GPTLMHeadModel(config, device='meta') # Without device='meta' init is very slow model = GPTLMHeadModel(config, device="meta") # Without device='meta' init is very slow
state_dict = model.state_dict() state_dict = model.state_dict()
assert state_dict.keys() == pretrained_state_dict.keys() assert state_dict.keys() == pretrained_state_dict.keys()
for k in state_dict.keys(): for k in state_dict.keys():
assert state_dict[k].shape == pretrained_state_dict[k].shape assert state_dict[k].shape == pretrained_state_dict[k].shape
@pytest.mark.parametrize('model_name', ["EleutherAI/gpt-j-6B"]) @pytest.mark.parametrize("model_name", ["EleutherAI/gpt-j-6B"])
def test_gptj_optimized(model_name): def test_gptj_optimized(model_name):
"""Check that our implementation of GPT-J (with all optimizations enabled) matches the """Check that our implementation of GPT-J (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32. forward pass in fp16, when compared to the HF forward pass in fp32.
""" """
dtype = torch.float16 dtype = torch.float16
device = 'cuda' device = "cuda"
config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name)) config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
config.use_flash_attn = True # FlashAttention-2 supports headdim 256 config.use_flash_attn = True # FlashAttention-2 supports headdim 256
config.fused_bias_fc = True config.fused_bias_fc = True
...@@ -46,8 +44,9 @@ def test_gptj_optimized(model_name): ...@@ -46,8 +44,9 @@ def test_gptj_optimized(model_name):
torch.manual_seed(0) torch.manual_seed(0)
batch_size = 2 batch_size = 2
max_seqlen = 256 max_seqlen = 256
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, input_ids = torch.randint(
device=device) 0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
)
with torch.no_grad(): with torch.no_grad():
out = model.transformer(input_ids) out = model.transformer(input_ids)
logits = model(input_ids).logits logits = model(input_ids).logits
...@@ -61,34 +60,37 @@ def test_gptj_optimized(model_name): ...@@ -61,34 +60,37 @@ def test_gptj_optimized(model_name):
logits_ref = model_ref(input_ids).logits logits_ref = model_ref(input_ids).logits
del model_ref del model_ref
model_hf = GPTJForCausalLM.from_pretrained(model_name, torch_dtype=dtype, model_hf = GPTJForCausalLM.from_pretrained(
device_map={"": device}) model_name, torch_dtype=dtype, device_map={"": device}
)
model_hf.eval() model_hf.eval()
out_hf = model_hf.transformer(input_ids).last_hidden_state out_hf = model_hf.transformer(input_ids).last_hidden_state
logits_hf = model_hf(input_ids).logits logits_hf = model_hf(input_ids).logits
del model_hf del model_hf
print(f'Output max diff: {(out - out_ref).abs().max().item()}') print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}') print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}') print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item() assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}') print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}') print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}') print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}') print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item() assert (logits - logits_ref).abs().max().item() < 3 * (
logits_hf - logits_ref
).abs().max().item()

@pytest.mark.parametrize('model_name', ["EleutherAI/gpt-j-6B"]) @pytest.mark.parametrize("model_name", ["EleutherAI/gpt-j-6B"])
def test_gptj_generation(model_name): def test_gptj_generation(model_name):
"""Check that our implementation of GPT-J (with all optimizations enabled) matches the """Check that our implementation of GPT-J (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32. forward pass in fp16, when compared to the HF forward pass in fp32.
""" """
dtype = torch.float16 dtype = torch.float16
device = 'cuda' device = "cuda"
config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name)) config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
config.use_flash_attn = True # FlashAttention-2 supports headdim 256 config.use_flash_attn = True # FlashAttention-2 supports headdim 256
config.fused_bias_fc = True config.fused_bias_fc = True
...@@ -104,56 +106,71 @@ def test_gptj_generation(model_name): ...@@ -104,56 +106,71 @@ def test_gptj_generation(model_name):
batch_size = 1 batch_size = 1
seqlen = 100 seqlen = 100
max_length = 150 max_length = 150
input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, input_ids = torch.randint(
device=device) 0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
)
model_hf = GPTJForCausalLM.from_pretrained(model_name, torch_dtype=dtype, model_hf = GPTJForCausalLM.from_pretrained(
device_map={"": device}) model_name, torch_dtype=dtype, device_map={"": device}
)
model_hf.eval() model_hf.eval()
print("HF fp16") print("HF fp16")
torch.cuda.synchronize() torch.cuda.synchronize()
start = time.time() start = time.time()
out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length, out_hf = model_hf.generate(
return_dict_in_generate=True, output_scores=True) input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
)
torch.cuda.synchronize() torch.cuda.synchronize()
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms') print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
del model_hf del model_hf
model_ref = GPTJForCausalLM.from_pretrained(model_name, device_map={"": device}) model_ref = GPTJForCausalLM.from_pretrained(model_name, device_map={"": device})
model_ref.eval() model_ref.eval()
with torch.no_grad(): with torch.no_grad():
logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1] logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]
del model_ref del model_ref
model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype) model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
model.eval() model.eval()
print('Without CUDA graph') print("Without CUDA graph")
torch.cuda.synchronize() torch.cuda.synchronize()
start = time.time() start = time.time()
out = model.generate(input_ids=input_ids, max_length=max_length, out = model.generate(
eos_token_id=eos_token_id, fused_ft_kernel=True, input_ids=input_ids,
max_length=max_length,
eos_token_id=eos_token_id,
fused_ft_kernel=True,
# eos_token_id=eos_token_id, fused_ft_kernel=False, # eos_token_id=eos_token_id, fused_ft_kernel=False,
return_dict_in_generate=True, output_scores=True, timing=True, return_dict_in_generate=True,
teacher_outputs=out_hf.sequences) output_scores=True,
timing=True,
teacher_outputs=out_hf.sequences,
)
torch.cuda.synchronize() torch.cuda.synchronize()
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms') print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
# Capture graph outside the timing loop # Capture graph outside the timing loop
batch_size, seqlen_og = input_ids.shape batch_size, seqlen_og = input_ids.shape
model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length) model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
print('With CUDA graph') print("With CUDA graph")
torch.cuda.synchronize() torch.cuda.synchronize()
start = time.time() start = time.time()
out_cg = model.generate(input_ids=input_ids, max_length=max_length, out_cg = model.generate(
fused_ft_kernel=True, cg=True, input_ids=input_ids,
return_dict_in_generate=True, output_scores=True, timing=True, max_length=max_length,
teacher_outputs=out_hf.sequences) fused_ft_kernel=True,
cg=True,
return_dict_in_generate=True,
output_scores=True,
timing=True,
teacher_outputs=out_hf.sequences,
)
torch.cuda.synchronize() torch.cuda.synchronize()
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms') print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
with torch.no_grad(): with torch.no_grad():
logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1):-1] logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1]
logits_hf = torch.stack(out_hf.scores, dim=1) logits_hf = torch.stack(out_hf.scores, dim=1)
logits = torch.stack(out.scores, dim=1) logits = torch.stack(out.scores, dim=1)
logits_cg = torch.stack(out_cg.scores, dim=1) logits_cg = torch.stack(out_cg.scores, dim=1)
...@@ -163,8 +180,8 @@ def test_gptj_generation(model_name): ...@@ -163,8 +180,8 @@ def test_gptj_generation(model_name):
hf_error = (logits_hf - logits_ref).abs().max().item() hf_error = (logits_hf - logits_ref).abs().max().item()
assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error
print(f'HF fp16 logits max diff: {hf_error}') print(f"HF fp16 logits max diff: {hf_error}")
print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}') print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
assert (logits - logits_ref).abs().max().item() < 2 * hf_error assert (logits - logits_ref).abs().max().item() < 2 * hf_error
print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}') print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}")
assert torch.equal(logits_cg, logits) assert torch.equal(logits_cg, logits)
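
Editor's note: feeding out_hf.sequences back as teacher_outputs makes all three decoders score the exact same token sequence, so the stacked generation scores can be compared against logits recomputed in a single teacher-forced forward pass; the slice [:, seqlen - 1 : -1] exists because the score for the token emitted at position t comes from the logits at position t - 1. A tiny check of that index bookkeeping (plain Python, no model involved):

seqlen, max_length = 100, 150
generated_positions = list(range(seqlen, max_length))          # tokens 100..149
scoring_positions = list(range(seqlen - 1, max_length - 1))    # logits 99..148
assert len(generated_positions) == len(scoring_positions) == max_length - seqlen
assert scoring_positions == [p - 1 for p in generated_positions]

The final torch.equal assertion then requires the CUDA-graph decoding path to be bitwise identical to the eager path, not merely close.
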
...@@ -11,26 +11,25 @@ from pathlib import Path ...@@ -11,26 +11,25 @@ from pathlib import Path
current_dir = Path(__file__).parent.absolute() current_dir = Path(__file__).parent.absolute()
import torch
import pytest
import shutil import shutil
import pytest
import torch
from einops import rearrange from einops import rearrange
from transformers import LlamaTokenizer, LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp, shard_state_dict_tp from flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp, shard_state_dict_tp
from flash_attn.models.llama import ( from flash_attn.models.llama import (
remap_state_dict_meta_llama, config_from_checkpoint,
inv_remap_state_dict_hf_llama,
llama_config_to_gpt2_config, llama_config_to_gpt2_config,
remap_state_dict_hf_llama, remap_state_dict_hf_llama,
inv_remap_state_dict_hf_llama, remap_state_dict_meta_llama,
state_dicts_from_checkpoint,
) )
from flash_attn.models.llama import config_from_checkpoint, state_dicts_from_checkpoint
from flash_attn.utils.distributed import all_gather_raw from flash_attn.utils.distributed import all_gather_raw
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.generation import update_graph_cache from flash_attn.utils.generation import update_graph_cache
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import LlamaConfig, LlamaTokenizer
from transformers.models.llama.modeling_llama import LlamaForCausalLM
def _pretrained_state_dict_from_checkpoint(checkpoint_path, model_name, config, checkpoint_format): def _pretrained_state_dict_from_checkpoint(checkpoint_path, model_name, config, checkpoint_format):
......
import re import re
import torch
import pytest import pytest
import torch
from transformers import OPTConfig
from transformers.models.opt.modeling_opt import OPTForCausalLM
from flash_attn.models.gpt import GPTLMHeadModel from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.opt import remap_state_dict_hf_opt, opt_config_to_gpt2_config from flash_attn.models.opt import opt_config_to_gpt2_config, remap_state_dict_hf_opt
from flash_attn.utils.pretrained import state_dict_from_pretrained from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import OPTConfig
from transformers.models.opt.modeling_opt import OPTForCausalLM
@pytest.mark.parametrize('model_name', ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b"]) @pytest.mark.parametrize(
"model_name", ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b"]
)
# @pytest.mark.parametrize('model_name', ["facebook/opt-350m"]) # @pytest.mark.parametrize('model_name', ["facebook/opt-350m"])
def test_opt_state_dict(model_name): def test_opt_state_dict(model_name):
config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name)) config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
...@@ -23,7 +23,9 @@ def test_opt_state_dict(model_name): ...@@ -23,7 +23,9 @@ def test_opt_state_dict(model_name):
assert state_dict[k].shape == pretrained_state_dict[k].shape assert state_dict[k].shape == pretrained_state_dict[k].shape
@pytest.mark.parametrize('model_name', ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b"]) @pytest.mark.parametrize(
"model_name", ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b"]
)
# @pytest.mark.parametrize('model_name', ["facebook/opt-350m"]) # @pytest.mark.parametrize('model_name', ["facebook/opt-350m"])
def test_opt_optimized(model_name): def test_opt_optimized(model_name):
"""Check that our implementation of OPT (without all optimizations enabled) matches the """Check that our implementation of OPT (without all optimizations enabled) matches the
...@@ -31,14 +33,14 @@ def test_opt_optimized(model_name): ...@@ -31,14 +33,14 @@ def test_opt_optimized(model_name):
forward pass in fp16, when compared to the HF forward pass in fp32. forward pass in fp16, when compared to the HF forward pass in fp32.
""" """
dtype = torch.float16 dtype = torch.float16
device = 'cuda' device = "cuda"
config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name)) config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
config.use_flash_attn = True config.use_flash_attn = True
config.fused_bias_fc = True config.fused_bias_fc = True
config.fused_mlp = True config.fused_mlp = True
config.fused_dropout_add_ln = True config.fused_dropout_add_ln = True
# Only prenorm supports residual_in_fp32 # Only prenorm supports residual_in_fp32
config.residual_in_fp32 = getattr(config, 'prenorm', True) config.residual_in_fp32 = getattr(config, "prenorm", True)
config.pad_vocab_size_multiple = 8 config.pad_vocab_size_multiple = 8
model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype) model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
...@@ -53,26 +55,29 @@ def test_opt_optimized(model_name): ...@@ -53,26 +55,29 @@ def test_opt_optimized(model_name):
torch.manual_seed(0) torch.manual_seed(0)
batch_size = 2 batch_size = 2
max_seqlen = 256 max_seqlen = 256
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda') seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, input_ids = torch.randint(
device='cuda') 0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
if model_name != 'facebook/opt-350m': # The OPT-350m projects the embeddings to dimension 512 )
if model_name != "facebook/opt-350m": # The OPT-350m projects the embeddings to dimension 512
out = model.transformer(input_ids) out = model.transformer(input_ids)
out_hf = model_hf.model(input_ids).last_hidden_state out_hf = model_hf.model(input_ids).last_hidden_state
out_ref = model_ref.model(input_ids).last_hidden_state out_ref = model_ref.model(input_ids).last_hidden_state
print(f'Output max diff: {(out - out_ref).abs().max().item()}') print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}') print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}') print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item() assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
logits = model(input_ids).logits logits = model(input_ids).logits
logits_hf = model_hf(input_ids).logits logits_hf = model_hf(input_ids).logits
logits_ref = model_ref(input_ids).logits logits_ref = model_ref(input_ids).logits
print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}') print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}') print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}') print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}') print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item() assert (logits - logits_ref).abs().max().item() < 3 * (
logits_hf - logits_ref
).abs().max().item()
import re import re
import torch
import pytest import pytest
import torch
from timm.models.vision_transformer import vit_base_patch16_224
from flash_attn.models.vit import vit_base_patch16_224 as flash_vit_base_patch16_224 from flash_attn.models.vit import vit_base_patch16_224 as flash_vit_base_patch16_224
from timm.models.vision_transformer import vit_base_patch16_224
@pytest.mark.parametrize('fused_mlp', [False, True]) @pytest.mark.parametrize("fused_mlp", [False, True])
# @pytest.mark.parametrize('fused_mlp', [False]) # @pytest.mark.parametrize('fused_mlp', [False])
@pytest.mark.parametrize('optimized', [False, True]) @pytest.mark.parametrize("optimized", [False, True])
# @pytest.mark.parametrize('optimized', [True]) # @pytest.mark.parametrize('optimized', [True])
def test_vit(optimized, fused_mlp): def test_vit(optimized, fused_mlp):
"""Check that our implementation of ViT matches the timm's implementation: """Check that our implementation of ViT matches the timm's implementation:
...@@ -18,12 +16,12 @@ def test_vit(optimized, fused_mlp): ...@@ -18,12 +16,12 @@ def test_vit(optimized, fused_mlp):
timm's forward pass in fp16, when compared to timm's forward pass in fp32. timm's forward pass in fp16, when compared to timm's forward pass in fp32.
""" """
dtype = torch.float16 dtype = torch.float16
device = 'cuda' device = "cuda"
kwargs = {} kwargs = {}
if optimized: if optimized:
kwargs = dict(use_flash_attn=True, fused_bias_fc=True, fused_dropout_add_ln=True) kwargs = dict(use_flash_attn=True, fused_bias_fc=True, fused_dropout_add_ln=True)
kwargs['fused_mlp'] = fused_mlp kwargs["fused_mlp"] = fused_mlp
model = flash_vit_base_patch16_224(**kwargs).to(device=device, dtype=dtype) model = flash_vit_base_patch16_224(**kwargs).to(device=device, dtype=dtype)
model_ref = vit_base_patch16_224(pretrained=True).to(device=device) model_ref = vit_base_patch16_224(pretrained=True).to(device=device)
...@@ -42,9 +40,9 @@ def test_vit(optimized, fused_mlp): ...@@ -42,9 +40,9 @@ def test_vit(optimized, fused_mlp):
out_timm = model_timm(x) out_timm = model_timm(x)
out_ref = model_ref(x.float()) out_ref = model_ref(x.float())
print(f'Output max diff: {(out - out_ref).abs().max().item()}') print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f'timm fp16 max diff: {(out_timm - out_ref).abs().max().item()}') print(f"timm fp16 max diff: {(out_timm - out_ref).abs().max().item()}")
print(f'timm fp16 mean diff: {(out_timm - out_ref).abs().mean().item()}') print(f"timm fp16 mean diff: {(out_timm - out_ref).abs().mean().item()}")
rtol = 2 if not fused_mlp else 8 rtol = 2 if not fused_mlp else 8
assert (out - out_ref).abs().max().item() < rtol * (out_timm - out_ref).abs().max().item() assert (out - out_ref).abs().max().item() < rtol * (out_timm - out_ref).abs().max().item()
...@@ -4,31 +4,27 @@ ...@@ -4,31 +4,27 @@
import math import math
from functools import partial from functools import partial
import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import pytest from apex.transformer import parallel_state, tensor_parallel
from einops import rearrange from einops import rearrange
from flash_attn.modules.block import Block
from apex.transformer import parallel_state
from apex.transformer import tensor_parallel
from flash_attn.modules.mha import MHA, ParallelMHA from flash_attn.modules.mha import MHA, ParallelMHA
from flash_attn.modules.mlp import FusedMLP, ParallelFusedMLP from flash_attn.modules.mlp import FusedMLP, ParallelFusedMLP
from flash_attn.modules.block import Block
from flash_attn.utils.distributed import allreduce_sequence_parallel_grad from flash_attn.utils.distributed import allreduce_sequence_parallel_grad
is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8 is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
@pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else [])) @pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.float16]) # @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('world_size', [1, 2, 4, 8]) @pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2]) # @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('sequence_parallel', [True, False]) @pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [True]) # @pytest.mark.parametrize('sequence_parallel', [True])
@pytest.mark.parametrize('dim', [1024]) @pytest.mark.parametrize("dim", [1024])
def test_block_parallel(dim, sequence_parallel, world_size, dtype): def test_block_parallel(dim, sequence_parallel, world_size, dtype):
head_dim = 64 head_dim = 64
assert dim % head_dim == 0 assert dim % head_dim == 0
...@@ -36,8 +32,8 @@ def test_block_parallel(dim, sequence_parallel, world_size, dtype): ...@@ -36,8 +32,8 @@ def test_block_parallel(dim, sequence_parallel, world_size, dtype):
assert num_heads % world_size == 0 assert num_heads % world_size == 0
rtol, atol = (3e-3, 5e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3) rtol, atol = (3e-3, 5e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
if not torch.distributed.is_initialized(): if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend='nccl', init_method='env://') torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f'cuda:{torch.distributed.get_rank()}' device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size() assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size) parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank() rank = parallel_state.get_tensor_model_parallel_rank()
...@@ -46,22 +42,37 @@ def test_block_parallel(dim, sequence_parallel, world_size, dtype): ...@@ -46,22 +42,37 @@ def test_block_parallel(dim, sequence_parallel, world_size, dtype):
batch_size = 2 batch_size = 2
seqlen = 1024 seqlen = 1024
assert (batch_size * seqlen) % world_size == 0 assert (batch_size * seqlen) % world_size == 0
x_pt = torch.randn(batch_size * seqlen, dim, device=device, dtype=dtype, x_pt = torch.randn(batch_size * seqlen, dim, device=device, dtype=dtype, requires_grad=True)
requires_grad=True)
residual_pt = torch.randn(batch_size * seqlen, dim, device=device, requires_grad=True) residual_pt = torch.randn(batch_size * seqlen, dim, device=device, requires_grad=True)
# We need to generate g here so that all processes get the same gradient, # We need to generate g here so that all processes get the same gradient,
# as rank 0 will have an extra bias that changes the RNG. # as rank 0 will have an extra bias that changes the RNG.
# If we don't scale g down (here by 32), the gradient gets a bit too large. # If we don't scale g down (here by 32), the gradient gets a bit too large.
g = torch.randn_like(x_pt) / 32 g = torch.randn_like(x_pt) / 32
if sequence_parallel: if sequence_parallel:
x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_() x = (
residual = tensor_parallel.scatter_to_sequence_parallel_region(residual_pt).detach().clone().requires_grad_() tensor_parallel.scatter_to_sequence_parallel_region(x_pt)
.detach()
.clone()
.requires_grad_()
)
residual = (
tensor_parallel.scatter_to_sequence_parallel_region(residual_pt)
.detach()
.clone()
.requires_grad_()
)
else: else:
x = x_pt.detach().clone().requires_grad_() x = x_pt.detach().clone().requires_grad_()
residual = residual_pt.detach().clone().requires_grad_() residual = residual_pt.detach().clone().requires_grad_()
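
Editor's note: under sequence parallelism the flattened (batch * seqlen, dim) activations are split along the token dimension, one contiguous chunk per rank, which is why every reference tensor later in the test is indexed with rank * partition_batch_dim : (rank + 1) * partition_batch_dim. A one-device stand-in for scatter_to_sequence_parallel_region (not the apex implementation, just the equivalent slicing):

import torch

torch.manual_seed(0)
batch_size, seqlen, dim, world_size = 2, 8, 4, 4
x_full = torch.randn(batch_size * seqlen, dim)
partition_batch_dim = batch_size * seqlen // world_size
chunks = torch.chunk(x_full, world_size, dim=0)        # one contiguous token chunk per rank
for rank in range(world_size):
    expected = x_full[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
    assert torch.equal(chunks[rank], expected)
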
mixer_cls_pt = partial(MHA, num_heads=num_heads, rotary_emb_dim=int(head_dim // 2), mixer_cls_pt = partial(
use_flash_attn=True, device=device, dtype=dtype) MHA,
num_heads=num_heads,
rotary_emb_dim=int(head_dim // 2),
use_flash_attn=True,
device=device,
dtype=dtype,
)
mlp_cls_pt = partial(FusedMLP, hidden_features=4 * dim, device=device, dtype=dtype) mlp_cls_pt = partial(FusedMLP, hidden_features=4 * dim, device=device, dtype=dtype)
norm_cls = partial(nn.LayerNorm, device=device, dtype=dtype) norm_cls = partial(nn.LayerNorm, device=device, dtype=dtype)
model_pt = Block(dim, mixer_cls_pt, mlp_cls_pt, norm_cls, fused_dropout_add_ln=True) model_pt = Block(dim, mixer_cls_pt, mlp_cls_pt, norm_cls, fused_dropout_add_ln=True)
...@@ -71,40 +82,68 @@ def test_block_parallel(dim, sequence_parallel, world_size, dtype): ...@@ -71,40 +82,68 @@ def test_block_parallel(dim, sequence_parallel, world_size, dtype):
nn.init.normal_(model_pt.norm2.weight) nn.init.normal_(model_pt.norm2.weight)
nn.init.normal_(model_pt.norm2.bias) nn.init.normal_(model_pt.norm2.bias)
mixer_cls = partial(ParallelMHA, num_heads=num_heads, mixer_cls = partial(
ParallelMHA,
num_heads=num_heads,
process_group=parallel_state.get_tensor_model_parallel_group(), process_group=parallel_state.get_tensor_model_parallel_group(),
rotary_emb_dim=int(head_dim // 2), use_flash_attn=True, rotary_emb_dim=int(head_dim // 2),
sequence_parallel=sequence_parallel, device=device, dtype=dtype) use_flash_attn=True,
mlp_cls = partial(ParallelFusedMLP, hidden_features=4 * dim, sequence_parallel=sequence_parallel,
device=device,
dtype=dtype,
)
mlp_cls = partial(
ParallelFusedMLP,
hidden_features=4 * dim,
process_group=parallel_state.get_tensor_model_parallel_group(), process_group=parallel_state.get_tensor_model_parallel_group(),
sequence_parallel=sequence_parallel, device=device, dtype=dtype) sequence_parallel=sequence_parallel,
model = Block(dim, mixer_cls, mlp_cls, norm_cls, fused_dropout_add_ln=True, device=device,
sequence_parallel=sequence_parallel, mark_shared_params=True) dtype=dtype,
)
model = Block(
dim,
mixer_cls,
mlp_cls,
norm_cls,
fused_dropout_add_ln=True,
sequence_parallel=sequence_parallel,
mark_shared_params=True,
)
partition_dim = dim // world_size partition_dim = dim // world_size
partition_hidden_dim = 4 * dim // world_size partition_hidden_dim = 4 * dim // world_size
with torch.no_grad(): with torch.no_grad():
model.mixer.Wqkv.weight.copy_( model.mixer.Wqkv.weight.copy_(
rearrange(rearrange(model_pt.mixer.Wqkv.weight, '(three o) i -> three o i', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim], rearrange(
'three o i -> (three o) i') rearrange(model_pt.mixer.Wqkv.weight, "(three o) i -> three o i", three=3)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"three o i -> (three o) i",
)
) )
model.mixer.Wqkv.bias.copy_( model.mixer.Wqkv.bias.copy_(
rearrange(rearrange(model_pt.mixer.Wqkv.bias, '(three o) -> three o', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim], rearrange(
'three o -> (three o)') rearrange(model_pt.mixer.Wqkv.bias, "(three o) -> three o", three=3)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"three o -> (three o)",
)
) )
model.mixer.out_proj.weight.copy_( model.mixer.out_proj.weight.copy_(
model_pt.mixer.out_proj.weight[:, rank * partition_dim:(rank + 1) * partition_dim] model_pt.mixer.out_proj.weight[:, rank * partition_dim : (rank + 1) * partition_dim]
) )
if rank == 0: if rank == 0:
model.mixer.out_proj.bias.copy_(model_pt.mixer.out_proj.bias) model.mixer.out_proj.bias.copy_(model_pt.mixer.out_proj.bias)
model.mlp.fc1.weight.copy_( model.mlp.fc1.weight.copy_(
model_pt.mlp.fc1.weight[rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim] model_pt.mlp.fc1.weight[rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim]
) )
model.mlp.fc1.bias.copy_( model.mlp.fc1.bias.copy_(
model_pt.mlp.fc1.bias[rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim] model_pt.mlp.fc1.bias[rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim]
) )
model.mlp.fc2.weight.copy_( model.mlp.fc2.weight.copy_(
model_pt.mlp.fc2.weight[:, rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim] model_pt.mlp.fc2.weight[
:, rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim
]
) )
if rank == 0: if rank == 0:
model.mlp.fc2.bias.copy_(model_pt.mlp.fc2.bias) model.mlp.fc2.bias.copy_(model_pt.mlp.fc2.bias)
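
Editor's note: the nested rearrange used to populate model.mixer.Wqkv above exists because the fused weight stacks Q, K and V along the output dimension; each rank must take its head slice from each of the three blocks separately and then re-fuse them. A one-device check of that pattern with a toy weight (illustrative sizes, not the test's):

import torch
from einops import rearrange

torch.manual_seed(0)
dim, world_size, rank = 16, 4, 1
partition_dim = dim // world_size
wqkv = torch.randn(3 * dim, dim)                 # rows 0..dim-1 = Q, then K, then V

shard = rearrange(
    rearrange(wqkv, "(three o) i -> three o i", three=3)[
        :, rank * partition_dim : (rank + 1) * partition_dim
    ],
    "three o i -> (three o) i",
)
# Equivalent explicit slicing of the Q, K and V blocks:
q, k, v = wqkv[:dim], wqkv[dim : 2 * dim], wqkv[2 * dim :]
manual = torch.cat(
    [blk[rank * partition_dim : (rank + 1) * partition_dim] for blk in (q, k, v)], dim=0
)
assert torch.equal(shard, manual)
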
...@@ -113,83 +152,122 @@ def test_block_parallel(dim, sequence_parallel, world_size, dtype): ...@@ -113,83 +152,122 @@ def test_block_parallel(dim, sequence_parallel, world_size, dtype):
model.norm2.weight.copy_(model_pt.norm2.weight) model.norm2.weight.copy_(model_pt.norm2.weight)
model.norm2.bias.copy_(model_pt.norm2.bias) model.norm2.bias.copy_(model_pt.norm2.bias)
mixer_kwargs = {'seqlen': seqlen} mixer_kwargs = {"seqlen": seqlen}
out, out_residual = model(x, residual, mixer_kwargs=mixer_kwargs) out, out_residual = model(x, residual, mixer_kwargs=mixer_kwargs)
out_pt, out_residual_pt = model_pt(rearrange(x_pt, '(b s) d -> b s d', s=seqlen), out_pt, out_residual_pt = model_pt(
rearrange(residual_pt, '(b s) d -> b s d', s=seqlen)) rearrange(x_pt, "(b s) d -> b s d", s=seqlen),
out_pt, out_residual_pt = [rearrange(x, 'b s d -> (b s) d') for x in [out_pt, out_residual_pt]] rearrange(residual_pt, "(b s) d -> b s d", s=seqlen),
)
out_pt, out_residual_pt = [rearrange(x, "b s d -> (b s) d") for x in [out_pt, out_residual_pt]]
partition_batch_dim = batch_size * seqlen // world_size partition_batch_dim = batch_size * seqlen // world_size
assert torch.allclose( assert torch.allclose(
out, out,
out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel else out_pt, if sequence_parallel
rtol=rtol, atol=atol else out_pt,
rtol=rtol,
atol=atol,
) )
assert torch.allclose( assert torch.allclose(
out_residual, out_residual,
out_residual_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] out_residual_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel else out_residual_pt, if sequence_parallel
rtol=rtol, atol=atol else out_residual_pt,
rtol=rtol,
atol=atol,
) )
(out_pt + 2 * out_residual_pt).backward(g) (out_pt + 2 * out_residual_pt).backward(g)
(out + 2 * out_residual).backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] (out + 2 * out_residual).backward(
if sequence_parallel else g) g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else g
)
allreduce_sequence_parallel_grad(model, parallel_state.get_tensor_model_parallel_group()) allreduce_sequence_parallel_grad(model, parallel_state.get_tensor_model_parallel_group())
parallel_state.destroy_model_parallel() parallel_state.destroy_model_parallel()
assert torch.allclose( assert torch.allclose(
x.grad, x.grad,
x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel else x_pt.grad, if sequence_parallel
rtol=rtol, atol=atol / 10 # magnitude of x.grad is quite small else x_pt.grad,
rtol=rtol,
atol=atol / 10, # magnitude of x.grad is quite small
) )
assert torch.allclose( assert torch.allclose(
residual.grad, residual.grad,
residual_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] residual_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel else residual_pt.grad, if sequence_parallel
rtol=rtol, atol=atol else residual_pt.grad,
rtol=rtol,
atol=atol,
) )
# The error for d_weight and d_bias is quite a bit higher # The error for d_weight and d_bias is quite a bit higher
assert torch.allclose( assert torch.allclose(
model.mixer.Wqkv.weight.grad, model.mixer.Wqkv.weight.grad,
rearrange(rearrange(model_pt.mixer.Wqkv.weight.grad, '(three o) i -> three o i', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim], rearrange(
'three o i -> (three o) i'), rearrange(model_pt.mixer.Wqkv.weight.grad, "(three o) i -> three o i", three=3)[
rtol=rtol, atol=atol * 10 :, rank * partition_dim : (rank + 1) * partition_dim
],
"three o i -> (three o) i",
),
rtol=rtol,
atol=atol * 10,
) )
assert torch.allclose( assert torch.allclose(
model.mixer.Wqkv.bias.grad, model.mixer.Wqkv.bias.grad,
rearrange(rearrange(model_pt.mixer.Wqkv.bias.grad, '(three o) -> three o', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim], rearrange(
'three o -> (three o)'), rearrange(model_pt.mixer.Wqkv.bias.grad, "(three o) -> three o", three=3)[
rtol=rtol, atol=atol * 5 :, rank * partition_dim : (rank + 1) * partition_dim
],
"three o -> (three o)",
),
rtol=rtol,
atol=atol * 5,
) )
assert torch.allclose( assert torch.allclose(
model.mixer.out_proj.weight.grad, model.mixer.out_proj.weight.grad,
model_pt.mixer.out_proj.weight.grad[:, rank * partition_dim:(rank + 1) * partition_dim], model_pt.mixer.out_proj.weight.grad[:, rank * partition_dim : (rank + 1) * partition_dim],
rtol=rtol, atol=atol * 10 rtol=rtol,
atol=atol * 10,
) )
if rank == 0: if rank == 0:
assert torch.allclose(model.mixer.out_proj.bias.grad, model_pt.mixer.out_proj.bias.grad, rtol=rtol, atol=atol * 5) assert torch.allclose(
model.mixer.out_proj.bias.grad,
model_pt.mixer.out_proj.bias.grad,
rtol=rtol,
atol=atol * 5,
)
assert torch.allclose( assert torch.allclose(
model.mlp.fc1.weight.grad, model.mlp.fc1.weight.grad,
model_pt.mlp.fc1.weight.grad[rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim], model_pt.mlp.fc1.weight.grad[
rtol=rtol, atol=atol * 10 rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim
],
rtol=rtol,
atol=atol * 10,
) )
assert torch.allclose( assert torch.allclose(
model.mlp.fc1.bias.grad, model.mlp.fc1.bias.grad,
model_pt.mlp.fc1.bias.grad[rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim], model_pt.mlp.fc1.bias.grad[rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim],
rtol=rtol, atol=atol * 5 rtol=rtol,
atol=atol * 5,
) )
assert torch.allclose( assert torch.allclose(
model.mlp.fc2.weight.grad, model.mlp.fc2.weight.grad,
model_pt.mlp.fc2.weight.grad[:, rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim], model_pt.mlp.fc2.weight.grad[
rtol=rtol, atol=atol * 10 :, rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim
],
rtol=rtol,
atol=atol * 10,
) )
if rank == 0: if rank == 0:
assert torch.allclose(model.mlp.fc2.bias.grad, model_pt.mlp.fc2.bias.grad, assert torch.allclose(
rtol=rtol, atol=atol * 5) model.mlp.fc2.bias.grad, model_pt.mlp.fc2.bias.grad, rtol=rtol, atol=atol * 5
)
assert torch.allclose(model.norm1.weight.grad, model_pt.norm1.weight.grad, rtol=rtol, atol=atol * 5) assert torch.allclose(
model.norm1.weight.grad, model_pt.norm1.weight.grad, rtol=rtol, atol=atol * 5
)
assert torch.allclose(model.norm1.bias.grad, model_pt.norm1.bias.grad, rtol=rtol, atol=atol * 5) assert torch.allclose(model.norm1.bias.grad, model_pt.norm1.bias.grad, rtol=rtol, atol=atol * 5)
assert torch.allclose(model.norm2.weight.grad, model_pt.norm2.weight.grad, rtol=rtol, atol=atol * 5) assert torch.allclose(
model.norm2.weight.grad, model_pt.norm2.weight.grad, rtol=rtol, atol=atol * 5
)
assert torch.allclose(model.norm2.bias.grad, model_pt.norm2.bias.grad, rtol=rtol, atol=atol * 5) assert torch.allclose(model.norm2.bias.grad, model_pt.norm2.bias.grad, rtol=rtol, atol=atol * 5)
# Run test with: # Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/modules/test_embedding_parallel.py # torchrun --no_python --nproc_per_node=8 pytest -q -s tests/modules/test_embedding_parallel.py
import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import pytest
from einops import rearrange
from apex.transformer import parallel_state from apex.transformer import parallel_state
from einops import rearrange
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8 is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
@pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else [])) @pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16]) # @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('world_size', [1, 2, 4, 8]) @pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2]) # @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('sequence_parallel', [True, False]) @pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False]) # @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize('has_pos_emb', [True, False]) @pytest.mark.parametrize("has_pos_emb", [True, False])
# @pytest.mark.parametrize('has_pos_emb', [True]) # @pytest.mark.parametrize('has_pos_emb', [True])
@pytest.mark.parametrize('dim', [1024]) @pytest.mark.parametrize("dim", [1024])
def test_embedding_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype): def test_embedding_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
vocab_size = 50264 vocab_size = 50264
seqlen = 2048 seqlen = 2048
@@ -31,8 +28,8 @@ def test_embedding_parallel(dim, has_pos_emb, sequence_parallel, world_size, dty
    assert dim % world_size == 0
    rtol, atol = (3e-3, 5e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
@@ -44,46 +41,66 @@ def test_embedding_parallel(dim, has_pos_emb, sequence_parallel, world_size, dty
    input_ids_pt = torch.randint(0, vocab_size, (batch_size, seqlen), device=device)
    input_ids = input_ids_pt.detach().clone()
    model_pt = GPT2Embeddings(
        dim, vocab_size, seqlen if has_pos_emb else 0, device=device, dtype=dtype
    )
    model = ParallelGPT2Embeddings(
        dim,
        vocab_size,
        seqlen if has_pos_emb else 0,
        parallel_state.get_tensor_model_parallel_group(),
        sequence_parallel=sequence_parallel,
        device=device,
        dtype=dtype,
    )
    partition_vocab_size = vocab_size // world_size
    partition_dim = dim // world_size
    with torch.no_grad():
        model.word_embeddings.weight.copy_(
            model_pt.word_embeddings.weight[
                rank * partition_vocab_size : (rank + 1) * partition_vocab_size
            ]
        )
        if has_pos_emb:
            model.position_embeddings.weight.copy_(
                model_pt.position_embeddings.weight[
                    :, rank * partition_dim : (rank + 1) * partition_dim
                ]
            )
    out = model(input_ids, combine_batch_seqlen_dim=True)
    out_pt = rearrange(model_pt(input_ids), "b s d -> (b s) d")
    partition_batch_dim = batch_size * seqlen // world_size
    assert torch.allclose(
        out,
        out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else out_pt,
        rtol=rtol,
        atol=atol,
    )
    g = torch.randn_like(out_pt)
    out_pt.backward(g)
    out.backward(
        g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else g
    )
    parallel_state.destroy_model_parallel()
    assert torch.allclose(
        model.word_embeddings.weight.grad,
        model_pt.word_embeddings.weight.grad[
            rank * partition_vocab_size : (rank + 1) * partition_vocab_size
        ],
        rtol=rtol,
        atol=atol,
    )
    if has_pos_emb:
        assert torch.allclose(
            model.position_embeddings.weight.grad,
            model_pt.position_embeddings.weight.grad[
                :, rank * partition_dim : (rank + 1) * partition_dim
            ],
            rtol=rtol,
            atol=atol,
        )
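
The asserts above compare each rank's shard against slices of the unsharded reference: word-embedding weights are split along the vocab (row) axis and position embeddings along the hidden (column) axis. A minimal standalone sketch of that split, with hypothetical sizes, plain PyTorch and no process group (illustrative, not part of this diff):

import torch

# Hypothetical sizes; vocab rows go to the word-embedding shards,
# hidden columns go to the position-embedding shards.
vocab_size, dim, max_pos, world_size = 16, 8, 32, 4
word_weight = torch.randn(vocab_size, dim)
pos_weight = torch.randn(max_pos, dim)

partition_vocab_size = vocab_size // world_size
partition_dim = dim // world_size
shards = [
    (
        word_weight[rank * partition_vocab_size : (rank + 1) * partition_vocab_size],
        pos_weight[:, rank * partition_dim : (rank + 1) * partition_dim],
    )
    for rank in range(world_size)
]

# Concatenating the shards along the split axes recovers the full weights.
assert torch.equal(torch.cat([w for w, _ in shards], dim=0), word_weight)
assert torch.equal(torch.cat([p for _, p in shards], dim=1), pos_weight)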
@@ -3,29 +3,25 @@
import math

import pytest
import torch
import torch.nn.functional as F
from apex.transformer import parallel_state, tensor_parallel
from einops import rearrange

from flash_attn.modules.mha import MHA, ParallelMHA

is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8


@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize("head_dim", [64, 128])
# @pytest.mark.parametrize('head_dim', [64])
@pytest.mark.parametrize("embed_dim", [1024, 4096])
# @pytest.mark.parametrize('embed_dim', [1024])
def test_mha_parallel(embed_dim, head_dim, sequence_parallel, world_size, dtype):
    assert embed_dim % head_dim == 0
@@ -33,8 +29,8 @@ def test_mha_parallel(embed_dim, head_dim, sequence_parallel, world_size, dtype)
    assert num_heads % world_size == 0
    rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
@@ -43,77 +39,122 @@ def test_mha_parallel(embed_dim, head_dim, sequence_parallel, world_size, dtype)
    batch_size = 2
    seqlen = 1024
    assert (batch_size * seqlen) % world_size == 0
    x_pt = torch.randn(
        batch_size * seqlen, embed_dim, device=device, dtype=dtype, requires_grad=True
    )
    # We need to generate g here so that all processes get the same gradient,
    # as rank 0 will have an extra bias that changes the RNG.
    # If we don't divide by batch_size, the gradient gets a bit too large.
    g = torch.randn_like(x_pt) / 32
    if sequence_parallel:
        x = (
            tensor_parallel.scatter_to_sequence_parallel_region(x_pt)
            .detach()
            .clone()
            .requires_grad_()
        )
    else:
        x = x_pt.detach().clone().requires_grad_()
    model_pt = MHA(
        embed_dim,
        num_heads,
        rotary_emb_dim=int(head_dim // 2),
        use_flash_attn=True,
        device=device,
        dtype=dtype,
    )
    partition_dim = embed_dim // world_size
    model = ParallelMHA(
        embed_dim,
        num_heads,
        parallel_state.get_tensor_model_parallel_group(),
        rotary_emb_dim=int(head_dim // 2),
        use_flash_attn=True,
        sequence_parallel=sequence_parallel,
        device=device,
        dtype=dtype,
    )
    with torch.no_grad():
        model.Wqkv.weight.copy_(
            rearrange(
                rearrange(model_pt.Wqkv.weight, "(three o) i -> three o i", three=3)[
                    :, rank * partition_dim : (rank + 1) * partition_dim
                ],
                "three o i -> (three o) i",
            )
        )
        model.Wqkv.bias.copy_(
            rearrange(
                rearrange(model_pt.Wqkv.bias, "(three o) -> three o", three=3)[
                    :, rank * partition_dim : (rank + 1) * partition_dim
                ],
                "three o -> (three o)",
            )
        )
        model.out_proj.weight.copy_(
            model_pt.out_proj.weight[:, rank * partition_dim : (rank + 1) * partition_dim]
        )
        if rank == 0:
            model.out_proj.bias.copy_(model_pt.out_proj.bias)
    out = model(x, seqlen=seqlen)
    out_pt = rearrange(model_pt(rearrange(x_pt, "(b s) d -> b s d", s=seqlen)), "b s d -> (b s) d")
    partition_batch_dim = batch_size * seqlen // world_size
    assert torch.allclose(
        out,
        out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else out_pt,
        rtol=rtol,
        atol=atol,
    )
    out_pt.backward(g)
    out.backward(
        g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else g
    )
    parallel_state.destroy_model_parallel()
    assert torch.allclose(
        x.grad,
        x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else x_pt.grad,
        rtol=rtol,
        atol=atol / 100,  # magnitude of x.grad is quite small
    )
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(
        model.Wqkv.weight.grad,
        rearrange(
            rearrange(model_pt.Wqkv.weight.grad, "(three o) i -> three o i", three=3)[
                :, rank * partition_dim : (rank + 1) * partition_dim
            ],
            "three o i -> (three o) i",
        ),
        rtol=rtol,
        atol=atol * 10,
    )
    assert torch.allclose(
        model.Wqkv.bias.grad,
        rearrange(
            rearrange(model_pt.Wqkv.bias.grad, "(three o) -> three o", three=3)[
                :, rank * partition_dim : (rank + 1) * partition_dim
            ],
            "three o -> (three o)",
        ),
        rtol=rtol,
        atol=atol * 5,
    )
    assert torch.allclose(
        model.out_proj.weight.grad,
        model_pt.out_proj.weight.grad[:, rank * partition_dim : (rank + 1) * partition_dim],
        rtol=rtol,
        atol=atol * 10,
    )
    if rank == 0:
        assert torch.allclose(
            model.out_proj.bias.grad, model_pt.out_proj.bias.grad, rtol=rtol, atol=atol * 5
        )
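
The parallel MHA keeps Q, K and V fused in a single Wqkv matrix, so the copies above take the same column block from each of the three projections and re-fuse them. A standalone sketch of that slicing with hypothetical sizes (plain tensors, not this repo's modules):

import torch
from einops import rearrange

# Hypothetical sizes: Wqkv stacks Q, K, V along the output dim ("(three o) i"),
# so a rank's shard is the same column block taken from each of the three.
embed_dim, world_size, rank = 8, 4, 1
partition_dim = embed_dim // world_size
wqkv = torch.randn(3 * embed_dim, embed_dim)

shard = rearrange(
    rearrange(wqkv, "(three o) i -> three o i", three=3)[
        :, rank * partition_dim : (rank + 1) * partition_dim
    ],
    "three o i -> (three o) i",
)
assert shard.shape == (3 * partition_dim, embed_dim)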
# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/modules/test_mlp_parallel.py

import pytest
import torch
import torch.nn.functional as F
from apex.transformer import parallel_state, tensor_parallel
from einops import rearrange

from flash_attn.modules.mlp import GatedMlp, ParallelGatedMlp

is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8


@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize("activation", [F.silu, F.sigmoid])
# @pytest.mark.parametrize('activation', [F.silu])
@pytest.mark.parametrize("dim", [1024, 4096])
# @pytest.mark.parametrize('dim', [1024])
def test_mlp_parallel(dim, activation, sequence_parallel, world_size, dtype):
    rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
@@ -39,34 +35,51 @@ def test_mlp_parallel(dim, activation, sequence_parallel, world_size, dtype):
    batch_size = 2
    seqlen = 1024
    assert (batch_size * seqlen) % world_size == 0
    x_pt = torch.randn(batch_size * seqlen, dim, device=device, dtype=dtype, requires_grad=True)
    # We need to generate g here so that all processes get the same gradient,
    # as rank 0 will have an extra bias that changes the RNG.
    # If we don't divide by batch_size, the gradient gets a bit too large.
    g = torch.randn_like(x_pt) / 32
    if sequence_parallel:
        x = (
            tensor_parallel.scatter_to_sequence_parallel_region(x_pt)
            .detach()
            .clone()
            .requires_grad_()
        )
    else:
        x = x_pt.detach().clone().requires_grad_()
    model_pt = GatedMlp(dim, activation=activation, device=device, dtype=dtype)
    partition_dim = model_pt.fc1.weight.shape[0] // 2 // world_size
    model = ParallelGatedMlp(
        dim,
        parallel_state.get_tensor_model_parallel_group(),
        activation=activation,
        sequence_parallel=sequence_parallel,
        device=device,
        dtype=dtype,
    )
    with torch.no_grad():
        model.fc1.weight.copy_(
            rearrange(
                rearrange(model_pt.fc1.weight, "(two o) i -> two o i", two=2)[
                    :, rank * partition_dim : (rank + 1) * partition_dim
                ],
                "two o i -> (two o) i",
            )
        )
        model.fc1.bias.copy_(
            rearrange(
                rearrange(model_pt.fc1.bias, "(two o) -> two o", two=2)[
                    :, rank * partition_dim : (rank + 1) * partition_dim
                ],
                "two o -> (two o)",
            )
        )
        model.fc2.weight.copy_(
            model_pt.fc2.weight[:, rank * partition_dim : (rank + 1) * partition_dim]
        )
        if rank == 0:
            model.fc2.bias.copy_(model_pt.fc2.bias)
@@ -76,39 +89,55 @@ def test_mlp_parallel(dim, activation, sequence_parallel, world_size, dtype):
    partition_batch_dim = batch_size * seqlen // world_size
    assert torch.allclose(
        out,
        out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else out_pt,
        rtol=rtol,
        atol=atol,
    )
    out_pt.backward(g)
    out.backward(
        g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else g
    )
    parallel_state.destroy_model_parallel()
    assert torch.allclose(
        x.grad,
        x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else x_pt.grad,
        rtol=rtol,
        atol=atol,
    )
    assert torch.allclose(
        model.fc1.weight.grad,
        rearrange(
            rearrange(model_pt.fc1.weight.grad, "(two o) i -> two o i", two=2)[
                :, rank * partition_dim : (rank + 1) * partition_dim
            ],
            "two o i -> (two o) i",
        ),
        rtol=rtol,
        atol=atol,
    )
    assert torch.allclose(
        model.fc1.bias.grad,
        rearrange(
            rearrange(model_pt.fc1.bias.grad, "(two o) -> two o", two=2)[
                :, rank * partition_dim : (rank + 1) * partition_dim
            ],
            "two o -> (two o)",
        ),
        rtol=rtol,
        atol=atol,
    )
    assert torch.allclose(
        model.fc2.weight.grad,
        model_pt.fc2.weight.grad[:, rank * partition_dim : (rank + 1) * partition_dim],
        rtol=rtol,
        atol=atol,
    )
    if rank == 0:
        assert torch.allclose(model.fc2.bias.grad, model_pt.fc2.bias.grad, rtol=rtol, atol=atol)
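
The split above makes fc1 column-parallel (each rank holds matching slices of the value and gate halves) and fc2 row-parallel, so summing the per-rank fc2 outputs reproduces the unsharded MLP. A standalone numerical sketch of that decomposition under assumed shapes and a silu gate (not the repo's GatedMlp class, no process group; the Python sum stands in for the all-reduce):

import torch
import torch.nn.functional as F

# Assumed shapes and gating convention (y * silu(gate)); purely illustrative.
torch.manual_seed(0)
dim, hidden, world_size = 8, 16, 4
x = torch.randn(2, dim)
w1 = torch.randn(2 * hidden, dim)  # stacks the value and gate projections
w2 = torch.randn(dim, hidden)

def gated_mlp(x, w1, w2):
    y, gate = (x @ w1.t()).chunk(2, dim=-1)
    return (y * F.silu(gate)) @ w2.t()

partition_dim = hidden // world_size
partials = []
for rank in range(world_size):
    sl = slice(rank * partition_dim, (rank + 1) * partition_dim)
    w1_shard = torch.cat([w1[:hidden][sl], w1[hidden:][sl]])  # matching value/gate slices
    partials.append(gated_mlp(x, w1_shard, w2[:, sl]))  # row-parallel fc2 slice

out = sum(partials)  # stands in for the all-reduce across tensor-parallel ranks
assert torch.allclose(out, gated_mlp(x, w1, w2), atol=1e-4)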
import math

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat

from flash_attn.ops.layer_norm import (
    DropoutAddLayerNorm,
    dropout_add_layer_norm,
    dropout_add_layer_norm_parallel_residual,
    dropout_add_layer_norm_subset,
)
from flash_attn.ops.rms_norm import (
    DropoutAddRMSNorm,
    dropout_add_rms_norm,
    dropout_add_rms_norm_parallel_residual,
    dropout_add_rms_norm_subset,
)

try:
    from apex.normalization import FusedRMSNorm
@@ -20,28 +24,42 @@ except:
    FusedRMSNorm, fused_rms_norm_affine = None, None

is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8


@pytest.mark.parametrize("is_rms_norm", [False, True])
@pytest.mark.parametrize("has_colscale", [True, False])
# @pytest.mark.parametrize('has_colscale', [False])
@pytest.mark.parametrize("has_rowscale", [True, False])
# @pytest.mark.parametrize('has_rowscale', [True])
@pytest.mark.parametrize("has_residual", [True, False])
# @pytest.mark.parametrize('has_residual', [False])
@pytest.mark.parametrize("dropout_p", [0.37, 0.0])
# @pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
@pytest.mark.parametrize(
    "hidden_size",
    [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],
)
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_training(
    hidden_size,
    input_dtype,
    residual_dtype,
    weight_dtype,
    dropout_p,
    has_residual,
    has_rowscale,
    has_colscale,
    is_rms_norm,
):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    if is_rms_norm and FusedRMSNorm is None:
@@ -49,15 +67,16 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
    layer_norm_cls = torch.nn.LayerNorm if not is_rms_norm else FusedRMSNorm
    our_layer_norm_cls = DropoutAddLayerNorm if not is_rms_norm else DropoutAddRMSNorm
    our_layer_norm_func = dropout_add_layer_norm if not is_rms_norm else dropout_add_rms_norm
    device = "cuda"
    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 1e-4)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x0_pt = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_colscale:
@@ -76,8 +95,8 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
        rowscale = torch.empty(batch_size, seqlen, device=device, dtype=input_dtype)
        survival_rate = 0.87
        rowscale = rowscale.bernoulli_(survival_rate) / survival_rate
        x0_scaled_pt = x0_pt * rearrange(rowscale, "... -> ... 1")
        x0_scaled_ref = x0_ref * rearrange(rowscale, "... -> ... 1")
    else:
        rowscale = None
        x0_scaled_pt = x0_pt
@@ -98,16 +117,29 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
            model.bias.copy_(model_pt.bias)
            model_ref.bias.copy_(model_pt.bias)
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, dmask = our_layer_norm_func(
        x0,
        res,
        model.weight,
        model.bias,
        model.p,
        model.eps,
        rowscale=rowscale,
        layerscale=colscale,
        residual_in_fp32=residual_in_fp32,
        return_dropout_mask=True,
    )
    assert out.dtype == input_dtype
    print(f"Actual dropout fraction: {1 - dmask.float().mean().item()}")
    if has_residual:
        residual_pt = (
            (x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + res_pt.float()
        ).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + res_ref
    else:
        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(
            dtype=residual_dtype
        )
        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p)
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)
    out_ref = model_ref(residual_ref)
@@ -119,24 +151,33 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
    out_ref.backward(g)
    assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
    if has_residual:
        assert (res.grad - res_ref.grad).abs().max() <= 4 * (
            res_pt.grad - res_ref.grad
        ).abs().max() + 1e-4
    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 3 * (
        model_pt.weight.grad - model_ref.weight.grad
    ).abs().max() + 3e-5
    if not is_rms_norm:
        assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (
            model_pt.bias.grad - model_ref.bias.grad
        ).abs().max() + 3e-5
    if has_colscale:
        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (
            colscale_pt.grad - colscale_ref.grad
        ).abs().max() + 2e-4
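
The fused call above is checked against an unfused reference: scale the input per row and per column, apply dropout, add the residual, then layer-norm the sum. A compact restatement of that reference in plain PyTorch, with assumed small shapes and an explicit dropout mask (illustrative only):

import torch
import torch.nn.functional as F

# Assumed small shapes; rowscale/colscale are ones just to show where they enter.
torch.manual_seed(0)
batch, seqlen, hidden, dropout_p = 2, 4, 8, 0.37
x0 = torch.randn(batch, seqlen, hidden)
res = torch.randn(batch, seqlen, hidden)
rowscale = torch.ones(batch, seqlen, 1)   # e.g. stochastic-depth row scaling
colscale = torch.ones(hidden)             # e.g. layerscale
weight, bias = torch.ones(hidden), torch.zeros(hidden)

dmask = (torch.rand(batch, seqlen, hidden) >= dropout_p).float()
residual = (x0 * rowscale * colscale) * dmask / (1 - dropout_p) + res
out = F.layer_norm(residual, (hidden,), weight, bias, eps=1e-5)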
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
@pytest.mark.parametrize("hidden_size", [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weight_dtype):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    device = "cuda"
    # rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 1e-4)
    dropout_p = 0.37
@@ -144,8 +185,9 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weigh
    torch.random.manual_seed(0)
    batch_size = 32
    seqlen = 512
    x0_pt = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
@@ -172,27 +214,39 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtyp
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
@pytest.mark.parametrize("is_rms_norm", [False, True])
@pytest.mark.parametrize("has_colscale", [True, False])
@pytest.mark.parametrize("has_rowscale", [True, False])
@pytest.mark.parametrize("has_residual", [True, False])
@pytest.mark.parametrize("dropout_p", [0.37, 0.0])
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize('has_colscale', [True])
# @pytest.mark.parametrize('has_rowscale', [False])
# @pytest.mark.parametrize('has_residual', [True])
# @pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
@pytest.mark.parametrize(
    "hidden_size",
    [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],
)
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_prenorm_training(
    hidden_size,
    input_dtype,
    residual_dtype,
    weight_dtype,
    dropout_p,
    has_residual,
    has_rowscale,
    has_colscale,
    is_rms_norm,
):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    if is_rms_norm and FusedRMSNorm is None:
@@ -200,15 +254,16 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
    layer_norm_cls = torch.nn.LayerNorm if not is_rms_norm else FusedRMSNorm
    our_layer_norm_cls = DropoutAddLayerNorm if not is_rms_norm else DropoutAddRMSNorm
    our_layer_norm_func = dropout_add_layer_norm if not is_rms_norm else dropout_add_rms_norm
    device = "cuda"
    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 2e-4)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x0_pt = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_colscale:
@@ -227,8 +282,8 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
        rowscale = torch.empty(batch_size, seqlen, device=device, dtype=input_dtype)
        survival_rate = 0.87
        rowscale = rowscale.bernoulli_(survival_rate) / survival_rate
        x0_scaled_pt = x0_pt * rearrange(rowscale, "... -> ... 1")
        x0_scaled_ref = x0_ref * rearrange(rowscale, "... -> ... 1")
    else:
        rowscale = None
        x0_scaled_pt = x0_pt
@@ -241,8 +296,9 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
    if not is_rms_norm:
        torch.nn.init.normal_(model_pt.bias)
    model_ref = layer_norm_cls(hidden_size).to(device=device, dtype=torch.float32)
    model = our_layer_norm_cls(
        hidden_size, prenorm=True, p=dropout_p, device=device, dtype=weight_dtype
    )
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model_ref.weight.copy_(model_pt.weight)
@@ -250,24 +306,38 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
            model.bias.copy_(model_pt.bias)
            model_ref.bias.copy_(model_pt.bias)
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, residual, dmask = our_layer_norm_func(
        x0,
        res,
        model.weight,
        model.bias,
        model.p,
        model.eps,
        rowscale=rowscale,
        layerscale=colscale,
        prenorm=True,
        residual_in_fp32=residual_in_fp32,
        return_dropout_mask=True,
    )
    print(f"Actual dropout fraction: {1 - dmask.float().mean().item()}")
    if has_residual:
        residual_pt = (
            (x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + res_pt.float()
        ).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + res_ref
    else:
        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(
            dtype=residual_dtype
        )
        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p)
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)
    out_ref = model_ref(residual_ref)
    assert out.dtype == input_dtype
    assert residual.dtype == residual_dtype
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
    assert (residual - residual_ref).abs().max() <= 4 * (
        residual_pt - residual_ref
    ).abs().max() + 1e-4
    g = torch.randn_like(out) / batch_size
    (out_pt * F.sigmoid(residual_pt)).backward(g)
@@ -275,24 +345,33 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
    (out_ref * F.sigmoid(residual_ref.to(dtype=residual_dtype))).backward(g)
    assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
    if has_residual:
        assert (res.grad - res_ref.grad).abs().max() <= 4 * (
            res_pt.grad - res_ref.grad
        ).abs().max() + 1e-4
    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (
        model_pt.weight.grad - model_ref.weight.grad
    ).abs().max() + 2e-4
    if not is_rms_norm:
        assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (
            model_pt.bias.grad - model_ref.bias.grad
        ).abs().max() + 2e-4
    if has_colscale:
        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (
            colscale_pt.grad - colscale_ref.grad
        ).abs().max() + 2e-4
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
@pytest.mark.parametrize("hidden_size", [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
def test_dropout_layer_norm_prenorm_eval(hidden_size, input_dtype, residual_dtype, weight_dtype):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    device = "cuda"
    # rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 1e-4)
    dropout_p = 0.37
@@ -300,8 +379,9 @@ def test_dropout_layer_norm_prenorm_eval(hidden_size, input_dtype, residual_dtyp
    torch.random.manual_seed(0)
    batch_size = 32
    seqlen = 512
    x0_pt = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
@@ -310,8 +390,9 @@ def test_dropout_layer_norm_prenorm_eval(hidden_size, input_dtype, residual_dtyp
    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    torch.nn.init.normal_(model_pt.bias)
    model = DropoutAddLayerNorm(
        hidden_size, prenorm=True, p=dropout_p, device=device, dtype=weight_dtype
    )
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
@@ -327,30 +408,36 @@ def test_dropout_layer_norm_prenorm_eval(hidden_size, input_dtype, residual_dtyp
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(input_dtype)
    out_ref = model_ref(residual_ref)
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
    assert (residual - residual_ref).abs().max() <= 4 * (
        residual_pt - residual_ref
    ).abs().max() + 1e-4
@pytest.mark.parametrize("has_colscale", [True, False])
@pytest.mark.parametrize("has_residual", [True, False])
@pytest.mark.parametrize("dropout_p", [0.37, 0.0])
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize('has_colscale', [True])
# @pytest.mark.parametrize('has_residual', [True])
# @pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
@pytest.mark.parametrize(
    "hidden_size",
    [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],
)
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_subset_training(
    hidden_size, input_dtype, residual_dtype, weight_dtype, dropout_p, has_residual, has_colscale
):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    device = "cuda"
    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 2e-4)
    # set seed
@@ -359,23 +446,28 @@ def test_dropout_layer_norm_subset_training(
    seqlen = 512
    drop_path_rate = 0.4
    drop_path_scale = 1 / (1 - drop_path_rate)

    def generate_droppath_masks(batch_size, seqlen, drop_path_rate, device):
        # Do it on CPU so we can get the numrows (with .item()) without GPU-CPU sync
        mask_batch = torch.rand(batch_size) < 1 - drop_path_rate
        numrows = (mask_batch).sum().item() * seqlen
        mask_batch = mask_batch.to(device=device, non_blocking=True)
        mask_batch_seqlen = repeat(mask_batch, "b -> (b s)", s=seqlen)
        subset = torch.cumsum(mask_batch_seqlen, dim=0, dtype=torch.int32).masked_fill_(
            ~mask_batch_seqlen, 0
        )
        return mask_batch, numrows, rearrange(subset, "(b s) -> b s", b=batch_size)

    x0_mask_batch, x0_numrows, x0_subset = generate_droppath_masks(
        batch_size, seqlen, drop_path_rate, device
    )
    out_mask_batch, out_numrows, out_subset = generate_droppath_masks(
        batch_size, seqlen, drop_path_rate, device
    )

    x0_pt = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0 = x0_pt.detach().clone()[x0_mask_batch].requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_colscale:
@@ -402,8 +494,9 @@ def test_dropout_layer_norm_subset_training(
    torch.nn.init.normal_(model_pt.weight)
    torch.nn.init.normal_(model_pt.bias)
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    model = DropoutAddLayerNorm(
        hidden_size, prenorm=False, p=dropout_p, device=device, dtype=weight_dtype
    )
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
@@ -412,25 +505,42 @@ def test_dropout_layer_norm_subset_training(
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, dmask = dropout_add_layer_norm_subset(
        x0,
        res,
        model.weight,
        model.bias,
        model.p,
        model.eps,
        layerscale=colscale,
        x0_subset=x0_subset,
        out_subset=out_subset,
        rowscale_const=drop_path_scale,
        out_numrows=out_numrows,
        prenorm=False,
        residual_in_fp32=residual_in_fp32,
        return_dropout_mask=True,
    )
    print(f"Actual dropout fraction: {1 - dmask.float().mean().item()}")

    x0_scaled_pt = (
        x0_scaled_pt.masked_fill(repeat(~x0_mask_batch, "b -> b s d", s=seqlen, d=hidden_size), 0)
        * drop_path_scale
    )
    x0_scaled_ref = (
        x0_scaled_ref.masked_fill(repeat(~x0_mask_batch, "b -> b s d", s=seqlen, d=hidden_size), 0)
        * drop_path_scale
    )
    dmask_expanded = torch.zeros_like(x0_pt, dtype=torch.uint8)
    dmask_expanded[x0_mask_batch] = dmask
    if has_residual:
        residual_pt = (
            (x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p) + res_pt.float()
        ).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + res_ref
    else:
        residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p)).to(
            dtype=residual_dtype
        )
        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p)
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)[out_mask_batch]
    out_ref = model_ref(residual_ref)[out_mask_batch]
@@ -441,36 +551,50 @@ def test_dropout_layer_norm_subset_training(
    out_pt.backward(g)
    out.backward(g)
    out_ref.backward(g)
    assert (x0.grad - x0_ref.grad[x0_mask_batch]).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad)[
        x0_mask_batch
    ].abs().max() + 1e-4
    if has_residual:
        assert (res.grad - res_ref.grad).abs().max() <= 4 * (
            res_pt.grad - res_ref.grad
        ).abs().max() + 1e-4
    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (
        model_pt.weight.grad - model_ref.weight.grad
    ).abs().max() + 2e-4
    assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (
        model_pt.bias.grad - model_ref.bias.grad
    ).abs().max() + 2e-4
    if has_colscale:
        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (
            colscale_pt.grad - colscale_ref.grad
        ).abs().max() + 2e-4
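
generate_droppath_masks above turns a per-batch drop-path decision into the index tensors the subset kernel consumes: surviving rows get 1-based positions from a cumulative sum, dropped rows get 0. A tiny worked example of that construction with hypothetical sizes (illustrative only):

import torch
from einops import rearrange, repeat

# Hypothetical sizes: batch of 3, seqlen of 2, with the middle sample dropped.
batch_size, seqlen = 3, 2
mask_batch = torch.tensor([True, False, True])
mask_batch_seqlen = repeat(mask_batch, "b -> (b s)", s=seqlen)  # [T, T, F, F, T, T]
subset = torch.cumsum(mask_batch_seqlen, dim=0, dtype=torch.int32).masked_fill_(
    ~mask_batch_seqlen, 0
)
print(rearrange(subset, "(b s) -> b s", b=batch_size))
# tensor([[1, 2],
#         [0, 0],
#         [3, 4]], dtype=torch.int32)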
@pytest.mark.parametrize("has_colscale", [True, False])
@pytest.mark.parametrize("has_residual", [True, False])
@pytest.mark.parametrize("dropout_p", [0.37, 0.0])
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize('has_colscale', [True])
# @pytest.mark.parametrize('has_residual', [True])
# @pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
@pytest.mark.parametrize(
    "hidden_size",
    [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],
)
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_subset_prenorm_training(
    hidden_size, input_dtype, residual_dtype, weight_dtype, dropout_p, has_residual, has_colscale
):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    device = "cuda"
    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 2e-4)
    # set seed

@@ -479,23 +603,28 @@ def test_dropout_layer_norm_subset_prenorm_training(

    seqlen = 512
    drop_path_rate = 0.4
    drop_path_scale = 1 / (1 - drop_path_rate)

    def generate_droppath_masks(batch_size, seqlen, drop_path_rate, device):
        # Do it on CPU so we can get the numrows (with .item()) without GPU-CPU sync
        mask_batch = torch.rand(batch_size) < 1 - drop_path_rate
        numrows = (mask_batch).sum().item() * seqlen
        mask_batch = mask_batch.to(device=device, non_blocking=True)
        mask_batch_seqlen = repeat(mask_batch, "b -> (b s)", s=seqlen)
        subset = torch.cumsum(mask_batch_seqlen, dim=0, dtype=torch.int32).masked_fill_(
            ~mask_batch_seqlen, 0
        )
        return mask_batch, numrows, rearrange(subset, "(b s) -> b s", b=batch_size)

    x0_mask_batch, x0_numrows, x0_subset = generate_droppath_masks(
        batch_size, seqlen, drop_path_rate, device
    )
    out_mask_batch, out_numrows, out_subset = generate_droppath_masks(
        batch_size, seqlen, drop_path_rate, device
    )
    x0_pt = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0 = x0_pt.detach().clone()[x0_mask_batch].requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_colscale:

@@ -522,8 +651,9 @@ def test_dropout_layer_norm_subset_prenorm_training(

    torch.nn.init.normal_(model_pt.weight)
    torch.nn.init.normal_(model_pt.bias)
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    model = DropoutAddLayerNorm(
        hidden_size, prenorm=True, p=dropout_p, device=device, dtype=weight_dtype
    )
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)

@@ -532,89 +662,139 @@ def test_dropout_layer_norm_subset_prenorm_training(

    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, residual, dmask = dropout_add_layer_norm_subset(
        x0,
        res,
        model.weight,
        model.bias,
        model.p,
        model.eps,
        layerscale=colscale,
        x0_subset=x0_subset,
        out_subset=out_subset,
        rowscale_const=drop_path_scale,
        out_numrows=out_numrows,
        prenorm=True,
        residual_in_fp32=residual_in_fp32,
        return_dropout_mask=True,
    )
    print(f"Actual dropout fraction: {1 - dmask.float().mean().item()}")
    x0_scaled_pt = (
        x0_scaled_pt.masked_fill(repeat(~x0_mask_batch, "b -> b s d", s=seqlen, d=hidden_size), 0)
        * drop_path_scale
    )
    x0_scaled_ref = (
        x0_scaled_ref.masked_fill(repeat(~x0_mask_batch, "b -> b s d", s=seqlen, d=hidden_size), 0)
        * drop_path_scale
    )
    dmask_expanded = torch.zeros_like(x0_pt, dtype=torch.uint8)
    dmask_expanded[x0_mask_batch] = dmask
    if has_residual:
        residual_pt = (
            (x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p) + res_pt.float()
        ).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + res_ref
    else:
        residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p)).to(
            dtype=residual_dtype
        )
        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p)
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)[out_mask_batch]
    out_ref = model_ref(residual_ref)[out_mask_batch]
    assert out.dtype == input_dtype
    assert residual.dtype == residual_dtype
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
    assert (residual - residual_ref).abs().max() <= 4 * (
        residual_pt - residual_ref
    ).abs().max() + 1e-4
    g = torch.randn_like(out) / batch_size
    (out_pt * F.sigmoid(residual_pt[out_mask_batch]) + residual_pt.mean(0, keepdim=True)).backward(
        g
    )
    (out * F.sigmoid(residual[out_mask_batch]) + residual.mean(0, keepdim=True)).backward(g)
    (
        out_ref * F.sigmoid(residual_ref[out_mask_batch].to(dtype=residual_dtype))
        + residual_ref.mean(0, keepdim=True)
    ).backward(g)
    assert (x0.grad - x0_ref.grad[x0_mask_batch]).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad)[
        x0_mask_batch
    ].abs().max() + 1e-4
    if has_residual:
        assert (res.grad - res_ref.grad).abs().max() <= 4 * (
            res_pt.grad - res_ref.grad
        ).abs().max() + 1e-4
    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (
        model_pt.weight.grad - model_ref.weight.grad
    ).abs().max() + 2e-4
    assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (
        model_pt.bias.grad - model_ref.bias.grad
    ).abs().max() + 2e-4
    if has_colscale:
        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (
            colscale_pt.grad - colscale_ref.grad
        ).abs().max() + 2e-4
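
# The parallel-residual tests feed two streams x0 and x1 through dropout, add them (plus an
# optional residual), and normalize the sum with either one shared norm (tied_norm) or two
# separate sets of LayerNorm/RMSNorm weights, comparing against F.layer_norm and Apex's
# fused_rms_norm_affine references.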
@pytest.mark.parametrize("is_rms_norm", [False, True])
# @pytest.mark.parametrize('is_rms_norm', [False])
@pytest.mark.parametrize("tied_norm", [False, True])
# @pytest.mark.parametrize('tied_norm', [False])
@pytest.mark.parametrize("has_residual", [True, False])
# @pytest.mark.parametrize('has_residual', [False])
@pytest.mark.parametrize("has_x1", [True, False])
# @pytest.mark.parametrize('has_x1', [True])
@pytest.mark.parametrize("dropout_p", [0.37, 0.0])
# @pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
# @pytest.mark.parametrize('weight_dtype', [torch.float16])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
@pytest.mark.parametrize(
    "hidden_size",
    [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],
)
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_parallel_residual_training(
    hidden_size,
    input_dtype,
    residual_dtype,
    weight_dtype,
    dropout_p,
    has_x1,
    has_residual,
    tied_norm,
    is_rms_norm,
):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    if is_rms_norm and fused_rms_norm_affine is None:
        pytest.skip()  # We need Apex's FusedRMSNorm to test
    our_layer_norm_func = (
        dropout_add_layer_norm_parallel_residual
        if not is_rms_norm
        else dropout_add_rms_norm_parallel_residual
    )
    device = "cuda"
    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 1e-4)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x0_pt = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_x1:
        x1_pt = torch.randn(
            batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
        )
        x1 = x1_pt.detach().clone().requires_grad_()
        x1_ref = x1_pt.detach().clone().float().requires_grad_()
    else:

@@ -626,16 +806,22 @@ def test_dropout_layer_norm_parallel_residual_training(

    else:
        res = None
    weight0 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
    bias0 = (
        torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        if not is_rms_norm
        else None
    )
    weight0_pt = weight0.detach().clone().requires_grad_()
    weight0_ref = weight0.detach().clone().float().requires_grad_()
    bias0_pt = bias0.detach().clone().requires_grad_() if bias0 is not None else None
    bias0_ref = bias0.detach().clone().float().requires_grad_() if bias0 is not None else None
    if not tied_norm:
        weight1 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        bias1 = (
            torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
            if not is_rms_norm
            else None
        )
        weight1_pt = weight1.detach().clone().requires_grad_()
        weight1_ref = weight1.detach().clone().float().requires_grad_()
        bias1_pt = bias1.detach().clone().requires_grad_() if bias1 is not None else None

@@ -646,48 +832,77 @@ def test_dropout_layer_norm_parallel_residual_training(

    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out0, out1, dmask0, dmask1 = our_layer_norm_func(
        x0,
        x1,
        res,
        weight0,
        bias0,
        weight1,
        bias1,
        dropout_p,
        epsilon,
        residual_in_fp32=residual_in_fp32,
        return_dropout_mask=True,
    )
    assert out0.dtype == input_dtype
    if not tied_norm:
        assert out1.dtype == input_dtype
    print(f"Actual dropout fraction: {1 - dmask0.float().mean().item()}")
    if has_residual:
        if has_x1:
            residual_pt = (
                (x0_pt.float() * dmask0.float()) / (1 - dropout_p)
                + (x1_pt.float() * dmask1.float()) / (1 - dropout_p)
                + res_pt.float()
            ).to(dtype=residual_dtype)
            residual_ref = (
                (x0_ref * dmask0.float()) / (1 - dropout_p)
                + (x1_ref * dmask1.float()) / (1 - dropout_p)
            ) + res_ref
        else:
            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p) + res_pt.float()).to(
                dtype=residual_dtype
            )
            residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p) + res_ref
    else:
        if has_x1:
            residual_pt = (
                (x0_pt.float() * dmask0.float()) / (1 - dropout_p)
                + (x1_pt.float() * dmask1.float()) / (1 - dropout_p)
            ).to(dtype=residual_dtype)
            residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p) + (
                x1_ref * dmask1.float()
            ) / (1 - dropout_p)
        else:
            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)).to(
                dtype=residual_dtype
            )
            residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p)
    if not is_rms_norm:
        out0_pt = F.layer_norm(
            residual_pt.to(dtype=weight_dtype), (hidden_size,), weight0_pt, bias0_pt, eps=epsilon
        ).to(dtype=input_dtype)
        out0_ref = F.layer_norm(residual_ref, (hidden_size,), weight0_ref, bias0_ref, eps=epsilon)
        if not tied_norm:
            out1_pt = F.layer_norm(
                residual_pt.to(dtype=weight_dtype),
                (hidden_size,),
                weight1_pt,
                bias1_pt,
                eps=epsilon,
            ).to(dtype=input_dtype)
            out1_ref = F.layer_norm(
                residual_ref, (hidden_size,), weight1_ref, bias1_ref, eps=epsilon
            )
    else:
        out0_pt = fused_rms_norm_affine(
            residual_pt.to(dtype=weight_dtype), weight0_pt, (hidden_size,), eps=epsilon
        ).to(dtype=input_dtype)
        out0_ref = fused_rms_norm_affine(residual_ref, weight0_ref, (hidden_size,), eps=epsilon)
        if not tied_norm:
            out1_pt = fused_rms_norm_affine(
                residual_pt.to(dtype=weight_dtype), weight1_pt, (hidden_size,), eps=epsilon
            ).to(dtype=input_dtype)
            out1_ref = fused_rms_norm_affine(residual_ref, weight1_ref, (hidden_size,), eps=epsilon)
    assert (out0 - out0_ref).abs().max() <= 4 * (out0_pt - out0_ref).abs().max() + 1e-4

@@ -706,61 +921,89 @@ def test_dropout_layer_norm_parallel_residual_training(

    (out0_ref * g0 + out1_ref * g1).sum().backward()
    assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
    if has_x1:
        assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (
            x1_pt.grad - x1_ref.grad
        ).abs().max() + 1e-4
    if has_residual:
        assert (res.grad - res_ref.grad).abs().max() <= 4 * (
            res_pt.grad - res_ref.grad
        ).abs().max() + 1e-4
    assert (weight0.grad - weight0_ref.grad).abs().max() <= 3 * (
        weight0_pt.grad - weight0_ref.grad
    ).abs().max() + 3e-5
    if not is_rms_norm:
        assert (bias0.grad - bias0_ref.grad).abs().max() <= 2 * (
            bias0_pt.grad - bias0_ref.grad
        ).abs().max() + 3e-5
    if not tied_norm:
        assert (weight1.grad - weight1_ref.grad).abs().max() <= 3 * (
            weight1_pt.grad - weight1_ref.grad
        ).abs().max() + 3e-5
        if not is_rms_norm:
            assert (bias1.grad - bias1_ref.grad).abs().max() <= 2 * (
                bias1_pt.grad - bias1_ref.grad
            ).abs().max() + 3e-5
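
# Prenorm version of the parallel-residual test: the fused kernel additionally returns the
# summed residual, which is checked directly and then reused in the backward pass.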
@pytest.mark.parametrize("is_rms_norm", [False, True])
# @pytest.mark.parametrize('is_rms_norm', [False])
@pytest.mark.parametrize("tied_norm", [False, True])
# @pytest.mark.parametrize('tied_norm', [False])
@pytest.mark.parametrize("has_residual", [True, False])
# @pytest.mark.parametrize('has_residual', [False])
@pytest.mark.parametrize("has_x1", [True, False])
# @pytest.mark.parametrize('has_x1', [True])
@pytest.mark.parametrize("dropout_p", [0.37, 0.0])
# @pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
# @pytest.mark.parametrize('weight_dtype', [torch.float16])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
@pytest.mark.parametrize(
    "hidden_size",
    [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],
)
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_parallel_residual_prenorm_training(
    hidden_size,
    input_dtype,
    residual_dtype,
    weight_dtype,
    dropout_p,
    has_x1,
    has_residual,
    tied_norm,
    is_rms_norm,
):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    if is_rms_norm and fused_rms_norm_affine is None:
        pytest.skip()  # We need Apex's FusedRMSNorm to test
    our_layer_norm_func = (
        dropout_add_layer_norm_parallel_residual
        if not is_rms_norm
        else dropout_add_rms_norm_parallel_residual
    )
    device = "cuda"
    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 1e-4)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x0_pt = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_x1:
        x1_pt = torch.randn(
            batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
        )
        x1 = x1_pt.detach().clone().requires_grad_()
        x1_ref = x1_pt.detach().clone().float().requires_grad_()
    else:

@@ -772,16 +1015,22 @@ def test_dropout_layer_norm_parallel_residual_prenorm_training(

    else:
        res = None
    weight0 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
    bias0 = (
        torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        if not is_rms_norm
        else None
    )
    weight0_pt = weight0.detach().clone().requires_grad_()
    weight0_ref = weight0.detach().clone().float().requires_grad_()
    bias0_pt = bias0.detach().clone().requires_grad_() if bias0 is not None else None
    bias0_ref = bias0.detach().clone().float().requires_grad_() if bias0 is not None else None
    if not tied_norm:
        weight1 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        bias1 = (
            torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
            if not is_rms_norm
            else None
        )
        weight1_pt = weight1.detach().clone().requires_grad_()
        weight1_ref = weight1.detach().clone().float().requires_grad_()
        bias1_pt = bias1.detach().clone().requires_grad_() if bias1 is not None else None

@@ -792,54 +1041,86 @@ def test_dropout_layer_norm_parallel_residual_prenorm_training(

    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out0, out1, residual, dmask0, dmask1 = our_layer_norm_func(
        x0,
        x1,
        res,
        weight0,
        bias0,
        weight1,
        bias1,
        dropout_p,
        epsilon,
        prenorm=True,
        residual_in_fp32=residual_in_fp32,
        return_dropout_mask=True,
    )
    assert out0.dtype == input_dtype
    if not tied_norm:
        assert out1.dtype == input_dtype
    print(f"Actual dropout fraction: {1 - dmask0.float().mean().item()}")
    if has_residual:
        if has_x1:
            residual_pt = (
                (x0_pt.float() * dmask0.float()) / (1 - dropout_p)
                + (x1_pt.float() * dmask1.float()) / (1 - dropout_p)
                + res_pt.float()
            ).to(dtype=residual_dtype)
            residual_ref = (
                (x0_ref * dmask0.float()) / (1 - dropout_p)
                + (x1_ref * dmask1.float()) / (1 - dropout_p)
            ) + res_ref
        else:
            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p) + res_pt.float()).to(
                dtype=residual_dtype
            )
            residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p) + res_ref
    else:
        if has_x1:
            residual_pt = (
                (x0_pt.float() * dmask0.float()) / (1 - dropout_p)
                + (x1_pt.float() * dmask1.float()) / (1 - dropout_p)
            ).to(dtype=residual_dtype)
            residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p) + (
                x1_ref * dmask1.float()
            ) / (1 - dropout_p)
        else:
            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)).to(
                dtype=residual_dtype
            )
            residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p)
    if not is_rms_norm:
        out0_pt = F.layer_norm(
            residual_pt.to(dtype=weight_dtype), (hidden_size,), weight0_pt, bias0_pt, eps=epsilon
        ).to(dtype=input_dtype)
        out0_ref = F.layer_norm(residual_ref, (hidden_size,), weight0_ref, bias0_ref, eps=epsilon)
        if not tied_norm:
            out1_pt = F.layer_norm(
                residual_pt.to(dtype=weight_dtype),
                (hidden_size,),
                weight1_pt,
                bias1_pt,
                eps=epsilon,
            ).to(dtype=input_dtype)
            out1_ref = F.layer_norm(
                residual_ref, (hidden_size,), weight1_ref, bias1_ref, eps=epsilon
            )
    else:
        out0_pt = fused_rms_norm_affine(
            residual_pt.to(dtype=weight_dtype), weight0_pt, (hidden_size,), eps=epsilon
        ).to(dtype=input_dtype)
        out0_ref = fused_rms_norm_affine(residual_ref, weight0_ref, (hidden_size,), eps=epsilon)
        if not tied_norm:
            out1_pt = fused_rms_norm_affine(
                residual_pt.to(dtype=weight_dtype), weight1_pt, (hidden_size,), eps=epsilon
            ).to(dtype=input_dtype)
            out1_ref = fused_rms_norm_affine(residual_ref, weight1_ref, (hidden_size,), eps=epsilon)
    assert (out0 - out0_ref).abs().max() <= 4 * (out0_pt - out0_ref).abs().max() + 1e-4
    if not tied_norm:
        assert (out1 - out1_ref).abs().max() <= 4 * (out1_pt - out1_ref).abs().max() + 1e-4
    assert (residual - residual_ref).abs().max() <= 4 * (
        residual_pt - residual_ref
    ).abs().max() + 1e-4
    g0 = torch.randn_like(out0) / batch_size
    if tied_norm:

@@ -853,39 +1134,56 @@ def test_dropout_layer_norm_parallel_residual_prenorm_training(

    (out0_ref * F.sigmoid(residual_ref) * g0 + out1_ref * g1).sum().backward()
    assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
    if has_x1:
        assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (
            x1_pt.grad - x1_ref.grad
        ).abs().max() + 1e-4
    if has_residual:
        assert (res.grad - res_ref.grad).abs().max() <= 4 * (
            res_pt.grad - res_ref.grad
        ).abs().max() + 1e-4
    assert (weight0.grad - weight0_ref.grad).abs().max() <= 3 * (
        weight0_pt.grad - weight0_ref.grad
    ).abs().max() + 3e-5
    if not is_rms_norm:
        assert (bias0.grad - bias0_ref.grad).abs().max() <= 2 * (
            bias0_pt.grad - bias0_ref.grad
        ).abs().max() + 3e-5
    if not tied_norm:
        assert (weight1.grad - weight1_ref.grad).abs().max() <= 3 * (
            weight1_pt.grad - weight1_ref.grad
        ).abs().max() + 3e-5
        if not is_rms_norm:
            assert (bias1.grad - bias1_ref.grad).abs().max() <= 2 * (
                bias1_pt.grad - bias1_ref.grad
            ).abs().max() + 3e-5
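
# Sanity check on randomness: consecutive calls must draw different dropout masks, while
# resetting torch's global seed must reproduce the same mask.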
def test_dropout_layer_norm_randomness():
    hidden_size = 256
    dtype = torch.float32
    dropout_p = 0.1
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x0 = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=dtype, requires_grad=True
    )
    res = torch.randn_like(x0, dtype=dtype, requires_grad=True)
    model = DropoutAddLayerNorm(hidden_size, p=dropout_p, device=device, dtype=dtype)
    torch.random.manual_seed(42)
    _, dmask0 = dropout_add_layer_norm(
        x0, res, model.weight, model.bias, model.p, model.eps, return_dropout_mask=True
    )
    # Subsequent call should have a different dropout mask
    _, dmask1 = dropout_add_layer_norm(
        x0, res, model.weight, model.bias, model.p, model.eps, return_dropout_mask=True
    )
    torch.random.manual_seed(42)
    # Resetting the seed, should get the same dropout mask
    _, dmask2 = dropout_add_layer_norm(
        x0, res, model.weight, model.bias, model.p, model.eps, return_dropout_mask=True
    )
    assert not torch.equal(dmask0, dmask1)
    assert torch.equal(dmask0, dmask2)