Commit 0e8c46ae authored by Tri Dao

Run isort and black on test files

parent 7fcd3e6a
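The diff below is the output of the two formatters named in the commit message. A minimal sketch of an equivalent invocation (assuming isort and black are installed, and that the project formats to a 100-character line length, which is inferred from the rewrapped lines and not recorded in this commit):

import subprocess

# Assumed reproduction of the formatting pass: isort first to sort and merge imports,
# then black to rewrap long lines; the paths and the line-length option are illustrative.
subprocess.run(["isort", "tests/"], check=True)
subprocess.run(["black", "--line-length", "100", "tests/"], check=True)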
@@ -2,26 +2,24 @@
import math

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange
from flash_attn.layers.rotary import RotaryEmbedding, apply_rotary_emb_func, apply_rotary_emb_qkv_
from transformers.models.gpt_neox.modeling_gpt_neox import RotaryEmbedding as RotaryEmbeddingNeoX
from transformers.models.gpt_neox.modeling_gpt_neox import (
    apply_rotary_pos_emb as apply_rotary_pos_emb_neox,
)
from transformers.models.gptj.modeling_gptj import apply_rotary_pos_emb as apply_rotary_pos_emb_gptj
from transformers.models.gptj.modeling_gptj import fixed_pos_embedding


# NeoX-style rotary embedding
@pytest.mark.parametrize("seqlen_offset", [0, 711])
@pytest.mark.parametrize("rotary_emb_fraction", [0.5, 1.0])
def test_rotary(rotary_emb_fraction, seqlen_offset):
    device = "cuda"
    dtype = torch.float16
    rtol, atol = (1e-3, 5e-3)
    # set seed

@@ -32,49 +30,70 @@ def test_rotary(rotary_emb_fraction, seqlen_offset):
    nheads = 16
    headdim = 128
    rotary_dim = int(headdim * rotary_emb_fraction)
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype, requires_grad=True
    )
    qkv_og = qkv.clone().detach()  # Our implementation modifies qkv inplace
    rotary = RotaryEmbedding(rotary_dim, device=device)
    rotary_neox = RotaryEmbeddingNeoX(rotary_dim, seqlen_total, device=device)
    # Doesn't matter what tensor we pass in, rotary_neox only uses the device of the tensor
    cos_neox, sin_neox = rotary_neox(qkv, seq_len=seqlen_total)
    cos_neox, sin_neox = cos_neox.to(dtype=dtype), sin_neox.to(dtype=dtype)
    q_pt = (
        rearrange(qkv[:, :, 0, :, :rotary_dim], "b s h d -> b h s d")
        .detach()
        .clone()
        .requires_grad_(True)
    )
    k_pt = (
        rearrange(qkv[:, :, 1, :, :rotary_dim], "b s h d -> b h s d")
        .detach()
        .clone()
        .requires_grad_(True)
    )
    q_neox, k_neox = apply_rotary_pos_emb_neox(q_pt, k_pt, cos_neox, sin_neox, offset=seqlen_offset)
    out = rotary(qkv, seqlen_offset=seqlen_offset)
    assert torch.allclose(
        rotary._cos_cached, cos_neox[..., : rotary_dim // 2].to(dtype=dtype), rtol=rtol, atol=atol
    )
    assert torch.allclose(
        rotary._sin_cached, sin_neox[..., : rotary_dim // 2].to(dtype=dtype), rtol=rtol, atol=atol
    )
    assert torch.allclose(
        rearrange(q_neox, "b h s d -> b s h d"), out[:, :, 0, :, :rotary_dim], rtol=rtol, atol=atol
    )
    assert torch.allclose(
        rearrange(k_neox, "b h s d -> b s h d"), out[:, :, 1, :, :rotary_dim], rtol=rtol, atol=atol
    )
    assert torch.equal(out[:, :, 0:2, :, rotary_dim:], qkv_og[:, :, 0:2, :, rotary_dim:])
    assert torch.equal(out[:, :, 2], qkv_og[:, :, 2])
    g = torch.randn_like(out)
    g_og = g.clone().detach()  # Our implementation modifies g inplace
    out.backward(g)
    q_neox.backward(rearrange(g_og[:, :, 0, :, :rotary_dim], "b s h d -> b h s d"))
    k_neox.backward(rearrange(g_og[:, :, 1, :, :rotary_dim], "b s h d -> b h s d"))
    assert torch.allclose(
        rearrange(q_pt.grad, "b h s d -> b s h d"),
        qkv.grad[:, :, 0, :, :rotary_dim],
        rtol=rtol,
        atol=atol,
    )
    assert torch.allclose(
        rearrange(k_pt.grad, "b h s d -> b s h d"),
        qkv.grad[:, :, 1, :, :rotary_dim],
        rtol=rtol,
        atol=atol,
    )
    assert torch.equal(qkv.grad[:, :, 0:2, :, rotary_dim:], g_og[:, :, 0:2, :, rotary_dim:])
    assert torch.equal(qkv.grad[:, :, 2], g_og[:, :, 2])


# GPT-J-style rotary embedding
@pytest.mark.parametrize("seqlen_offset", [0, 711])
@pytest.mark.parametrize("rotary_emb_fraction", [0.5, 1.0])
def test_rotary_interleaved(rotary_emb_fraction, seqlen_offset):
    device = "cuda"
    dtype = torch.float16
    rtol, atol = (1e-3, 5e-3)
    # set seed

@@ -85,8 +104,9 @@ def test_rotary_interleaved(rotary_emb_fraction, seqlen_offset):
    nheads = 16
    headdim = 128
    rotary_dim = int(headdim * rotary_emb_fraction)
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype, requires_grad=True
    )
    qkv_og = qkv.clone().detach()  # Our implementation modifies qkv inplace
    rotary = RotaryEmbedding(rotary_dim, interleaved=True, device=device)
    sincos_gptj = fixed_pos_embedding(qkv[..., :rotary_dim], seq_dim=1, seq_len=seqlen_total)
...
import math

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange
from flash_attn.losses.cross_entropy import CrossEntropyLossApex

is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8


@pytest.mark.parametrize(
    "dtype", [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else [])
)
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("inplace_backward", [False, True])
# @pytest.mark.parametrize('inplace_backward', [False])
@pytest.mark.parametrize("smoothing", [0.0, 0.9])
@pytest.mark.parametrize("vocab_size", [50257])
def test_cross_entropy_loss_apex(vocab_size, smoothing, inplace_backward, dtype):
    device = "cuda"
    rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 128
    x_pt = torch.randn(
        batch_size * seqlen, vocab_size, device=device, dtype=dtype, requires_grad=True
    )
    x = x_pt.detach().clone().requires_grad_()
    y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device)
    y[torch.randperm(batch_size * seqlen)[:10]] = -100
...
@@ -3,35 +3,37 @@
import math

import pytest
import torch
import torch.nn.functional as F
from apex.transformer import parallel_state, tensor_parallel
from flash_attn.losses.cross_entropy import CrossEntropyLoss

is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8


@pytest.mark.parametrize(
    "dtype", [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else [])
)
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("inplace_backward", [False, True])
# @pytest.mark.parametrize('inplace_backward', [False])
@pytest.mark.parametrize("smoothing", [0.0, 0.9])
# @pytest.mark.parametrize('smoothing', [0.9])
@pytest.mark.parametrize("vocab_size", [50264])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
def test_cross_entropy_loss_parallel(vocab_size, world_size, smoothing, inplace_backward, dtype):
    assert vocab_size % world_size == 0
    rtol, atol = (
        (1e-5, 1e-6)
        if dtype == torch.float32
        else ((1e-3, 1e-4) if dtype == torch.float16 else (1e-2, 3e-3))
    )
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    partition_vocab_size = vocab_size // world_size
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()

@@ -39,15 +41,24 @@ def test_cross_entropy_loss_parallel(vocab_size, world_size, smoothing, inplace_
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 128
    x_pt = (
        torch.randn(batch_size * seqlen, vocab_size, device=device, dtype=dtype) * 10
    ).requires_grad_()
    x = (
        tensor_parallel.scatter_to_tensor_model_parallel_region(x_pt)
        .detach()
        .clone()
        .requires_grad_()
    )
    y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device)
    y[torch.randperm(batch_size * seqlen)[:10]] = -100
    model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing, reduction="none")
    model = CrossEntropyLoss(
        label_smoothing=smoothing,
        reduction="none",
        inplace_backward=inplace_backward,
        process_group=parallel_state.get_tensor_model_parallel_group(),
    )
    out = model(x, y)
    out_pt = model_pt(x_pt.float(), y)
    assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)

@@ -55,6 +66,11 @@ def test_cross_entropy_loss_parallel(vocab_size, world_size, smoothing, inplace_
    g = torch.randn_like(out)
    out_pt.backward(g)
    out.backward(g)
    assert torch.allclose(
        x.grad,
        x_pt.grad[:, (rank * partition_vocab_size) : (rank + 1) * partition_vocab_size],
        rtol=rtol,
        atol=atol,
    )
    parallel_state.destroy_model_parallel()
import re
from collections import OrderedDict

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange
from flash_attn.models.bert import BertForPreTraining, BertModel, remap_state_dict
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertForPreTraining as BertForPreTrainingHF
from transformers.models.bert.modeling_bert import BertModel as BertModelHF


@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_state_dict(model_name):
    config = BertConfig.from_pretrained(model_name)

@@ -30,12 +26,15 @@ def test_bert_state_dict(model_name):
def get_hf_models(model_name, config, dtype):
    pretrained_state_dict = state_dict_from_pretrained(model_name)

    def key_mapping_ln_gamma_beta(key):
        key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
        key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key)
        return key

    pretrained_state_dict = OrderedDict(
        (key_mapping_ln_gamma_beta(k), v) for k, v in pretrained_state_dict.items()
    )
    model_hf = BertForPreTrainingHF(config)
    # Missing key(s) in state_dict: "bert.embeddings.position_ids", "cls.predictions.decoder.bias"
    # position_ids is a buffer, and predictions.decoder.bias is tied to predictions.bias.

@@ -44,7 +43,7 @@ def get_hf_models(model_name, config, dtype):
    return model_hf


@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_non_optimized(model_name):
    """Check that our implementation of BERT (without any optimizations enabled) matches the

@@ -67,10 +66,11 @@ def test_bert_non_optimized(model_name):
    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
    attention_mask = torch.arange(max_seqlen, device="cuda")[None, :] < seqlens[:, None]
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
    )
    out = model.bert(input_ids, attention_mask=attention_mask)
    sequence_output, pooled_output = out.last_hidden_state, out.pooler_output
    out_hf = model_hf.bert(input_ids, attention_mask=attention_mask)

@@ -78,15 +78,19 @@ def test_bert_non_optimized(model_name):
    out_ref = model_ref.bert(input_ids, attention_mask=attention_mask)
    sequence_output_ref, pooled_output_ref = out_ref.last_hidden_state, out_ref.pooler_output
    print(f"Output max diff: {(sequence_output - sequence_output_ref).abs().max().item()}")
    print(f"Output mean diff: {(sequence_output - sequence_output_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(sequence_output_hf - sequence_output_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(sequence_output_hf - sequence_output_ref).abs().mean().item()}")
    assert (sequence_output - sequence_output_ref).abs().max().item() < 3 * (
        sequence_output_hf - sequence_output_ref
    ).abs().max().item()
    assert (pooled_output - pooled_output_ref).abs().max().item() < 3 * (
        pooled_output_hf - pooled_output_ref
    ).abs().max().item()


@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_optimized(model_name):
    """Check that our implementation of BERT (with all optimizations enabled) matches the

@@ -117,10 +121,11 @@ def test_bert_optimized(model_name):
    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
    attention_mask = torch.arange(max_seqlen, device="cuda")[None, :] < seqlens[:, None]
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
    )
    out = model.bert(input_ids, attention_mask=attention_mask)
    sequence_output, pooled_output = out.last_hidden_state, out.pooler_output
    out_hf = model_hf.bert(input_ids, attention_mask=attention_mask)

@@ -131,12 +136,24 @@ def test_bert_optimized(model_name):
    sequence_output_ref, pooled_output_ref = out_ref.last_hidden_state, out_ref.pooler_output
    sequence_output_ref[~attention_mask, :] = 0.0
    print(
        f"BertModel output max diff: {(sequence_output - sequence_output_ref).abs().max().item()}"
    )
    print(
        f"BertModel output mean diff: {(sequence_output - sequence_output_ref).abs().mean().item()}"
    )
    print(
        f"HF fp16 BertModel max diff: {(sequence_output_hf - sequence_output_ref).abs().max().item()}"
    )
    print(
        f"HF fp16 BertModel mean diff: {(sequence_output_hf - sequence_output_ref).abs().mean().item()}"
    )
    assert (sequence_output - sequence_output_ref).abs().max().item() < 4 * (
        sequence_output_hf - sequence_output_ref
    ).abs().max().item()
    assert (pooled_output - pooled_output_ref).abs().max().item() < 4 * (
        pooled_output_hf - pooled_output_ref
    ).abs().max().item()
    out = model(input_ids, attention_mask=attention_mask)
    prediction_scores, seq_relationship_scores = out.prediction_logits, out.seq_relationship_logits

@@ -144,25 +161,43 @@ def test_bert_optimized(model_name):
    prediction_scores = prediction_scores.clone()
    prediction_scores[~attention_mask, :] = 0.0
    out_hf = model_hf(input_ids, attention_mask=attention_mask)
    prediction_scores_hf, seq_relationship_scores_hf = (
        out_hf.prediction_logits,
        out_hf.seq_relationship_logits,
    )
    prediction_scores_hf[~attention_mask, :] = 0.0
    out_ref = model_ref(input_ids, attention_mask=attention_mask)
    prediction_scores_ref, seq_relationship_scores_ref = (
        out_ref.prediction_logits,
        out_ref.seq_relationship_logits,
    )
    prediction_scores_ref[~attention_mask, :] = 0.0
    print(
        f"prediction_scores max diff: {(prediction_scores - prediction_scores_ref).abs().max().item()}"
    )
    print(
        f"prediction_scores mean diff: {(prediction_scores - prediction_scores_ref).abs().mean().item()}"
    )
    print(
        f"HF fp16 prediction_scores max diff: {(prediction_scores_hf - prediction_scores_ref).abs().max().item()}"
    )
    print(
        f"HF fp16 prediction_scores mean diff: {(prediction_scores_hf - prediction_scores_ref).abs().mean().item()}"
    )
    assert (prediction_scores - prediction_scores_ref).abs().max().item() < 2 * (
        prediction_scores_hf - prediction_scores_ref
    ).abs().max().item()
    assert (seq_relationship_scores - seq_relationship_scores_ref).abs().max().item() < 2 * (
        seq_relationship_scores_hf - seq_relationship_scores_ref
    ).abs().max().item()

@pytest.mark.parametrize("last_layer_subset", [False, True])
# @pytest.mark.parametrize('last_layer_subset', [True])
@pytest.mark.parametrize("has_key_padding_mask", [True, False])
# @pytest.mark.parametrize('has_key_padding_mask', [True])
@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_dense_seq_output(model_name, has_key_padding_mask, last_layer_subset):
    """Check that our implementation of BERT (with all optimizations enabled) matches the

@@ -196,40 +231,70 @@ def test_bert_dense_seq_output(model_name, has_key_padding_mask, last_layer_subs
    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
    if has_key_padding_mask:
        attention_mask = torch.arange(max_seqlen, device="cuda")[None, :] < seqlens[:, None]
    else:
        attention_mask = None
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
    )
    labels = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
    )
    if attention_mask is not None:
        labels[~attention_mask] = 0
    labels[(torch.rand(batch_size, max_seqlen, device="cuda") > 0.15)] = 0
    masked_tokens_mask = labels.flatten() > 0
    next_sequence_label = torch.randint(0, 2, (batch_size,), device="cuda")
    out = model(
        input_ids,
        attention_mask=attention_mask,
        labels=labels,
        next_sentence_label=next_sequence_label,
    )
    prediction_scores, seq_relationship_scores = out.prediction_logits, out.seq_relationship_logits
    out_hf = model_hf(
        input_ids,
        attention_mask=attention_mask,
        labels=labels,
        next_sentence_label=next_sequence_label,
    )
    prediction_scores_hf, seq_relationship_scores_hf = (
        out_hf.prediction_logits,
        out_hf.seq_relationship_logits,
    )
    prediction_scores_hf = rearrange(prediction_scores_hf, "b s d -> (b s) d")[masked_tokens_mask]
    out_ref = model_ref(
        input_ids,
        attention_mask=attention_mask,
        labels=labels,
        next_sentence_label=next_sequence_label,
    )
    prediction_scores_ref, seq_relationship_scores_ref = (
        out_ref.prediction_logits,
        out_ref.seq_relationship_logits,
    )
    prediction_scores_ref = rearrange(prediction_scores_ref, "b s d -> (b s) d")[masked_tokens_mask]
    print(
        f"prediction_scores max diff: {(prediction_scores - prediction_scores_ref).abs().max().item()}"
    )
    print(
        f"prediction_scores mean diff: {(prediction_scores - prediction_scores_ref).abs().mean().item()}"
    )
    print(
        f"HF fp16 prediction_scores max diff: {(prediction_scores_hf - prediction_scores_ref).abs().max().item()}"
    )
    print(
        f"HF fp16 prediction_scores mean diff: {(prediction_scores_hf - prediction_scores_ref).abs().mean().item()}"
    )
    assert (prediction_scores - prediction_scores_ref).abs().max().item() < 2 * (
        prediction_scores_hf - prediction_scores_ref
    ).abs().max().item()
    assert (seq_relationship_scores - seq_relationship_scores_ref).abs().max().item() < 2 * (
        seq_relationship_scores_hf - seq_relationship_scores_ref
    ).abs().max().item()
    # The loss calculation from HF is wrong: it doesn't ignore the labels that are 0.
    # assert (out.loss - out_ref.loss).abs().max().item() < 2 * (out_hf.loss - out_ref.loss).abs().max().item()
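The commented-out loss assertion above is disabled because of the mismatch noted in the comment: this test marks unmasked positions with label 0 rather than -100, and HF's loss does not skip label 0. A minimal sketch of a masked-LM loss restricted to those positions, assuming that labeling convention (the helper below is hypothetical and not part of the flash_attn API):

import torch
import torch.nn.functional as F


def mlm_loss_ignoring_zero_labels(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Cross-entropy over flattened token positions, keeping only labels > 0,
    # mirroring the masked_tokens_mask = labels.flatten() > 0 convention used in the test.
    flat_logits = logits.reshape(-1, logits.shape[-1])
    flat_labels = labels.reshape(-1)
    keep = flat_labels > 0
    return F.cross_entropy(flat_logits[keep].float(), flat_labels[keep])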
(The diff for one additional file is collapsed and not shown here.)
import re

import pytest
import torch
from flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_gpt2
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF


@pytest.mark.parametrize("model_name", ["gpt2", "gpt2-medium"])
# @pytest.mark.parametrize('model_name', ["gpt2"])
def test_gpt2_state_dict(model_name):
    config = GPT2Config.from_pretrained(model_name)

@@ -23,7 +20,7 @@ def test_gpt2_state_dict(model_name):
        assert state_dict[k].shape == pretrained_state_dict[k].shape


@pytest.mark.parametrize("model_name", ["gpt2", "gpt2-medium"])
# @pytest.mark.parametrize('model_name', ["gpt2"])
def test_gpt2_non_optimized(model_name):
    """Check that our implementation of GPT2 (without any optimizations enabled) matches the

@@ -46,31 +43,34 @@ def test_gpt2_non_optimized(model_name):
    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
    )
    out = model.transformer(input_ids)
    out_hf = model_hf.transformer(input_ids).last_hidden_state
    out_ref = model_ref.transformer(input_ids).last_hidden_state
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
    logits = model(input_ids).logits
    logits_hf = model_hf(input_ids).logits
    logits_ref = model_ref(input_ids).logits
    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
    print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
    assert (logits - logits_ref).abs().max().item() < 3 * (
        logits_hf - logits_ref
    ).abs().max().item()


@pytest.mark.parametrize("model_name", ["gpt2", "gpt2-medium"])
# @pytest.mark.parametrize('model_name', ["gpt2"])
def test_gpt2_optimized(model_name):
    """Check that our implementation of GPT2 (with all optimizations enabled) matches the

@@ -100,25 +100,28 @@ def test_gpt2_optimized(model_name):
    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
    input_ids = torch.randint(
        0, vocab_size_og, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
    )
    out = model.transformer(input_ids)
    out_hf = model_hf.transformer(input_ids).last_hidden_state
    out_ref = model_ref.transformer(input_ids).last_hidden_state
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
    logits = model(input_ids).logits[..., :vocab_size_og]
    logits_hf = model_hf(input_ids).logits
    logits_ref = model_ref(input_ids).logits
    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
    print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
    assert (logits - logits_ref).abs().max().item() < 3 * (
        logits_hf - logits_ref
    ).abs().max().item()
@@ -2,36 +2,32 @@ import os
import re
import time

import pytest
import torch
from einops import rearrange
from flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_gpt2
from flash_attn.models.opt import opt_config_to_gpt2_config, remap_state_dict_hf_opt
from flash_attn.utils.generation import update_graph_cache
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import AutoTokenizer, GPT2Config, GPT2Tokenizer, OPTConfig
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF
from transformers.models.opt.modeling_opt import OPTForCausalLM


@pytest.mark.parametrize("fused_ft_kernel", [False, True])
# @pytest.mark.parametrize('fused_ft_kernel', [True])
@pytest.mark.parametrize("optimized", [False, True])
# @pytest.mark.parametrize('optimized', [False])
@pytest.mark.parametrize("rotary", [False, True])
# @pytest.mark.parametrize('rotary', [False])
@pytest.mark.parametrize("model_name", ["gpt2"])
def test_greedy_decode_gpt2(model_name, rotary, optimized, fused_ft_kernel):
    """Check that our implementation of GPT2 generation matches the HF implementation:
    the scores in fp16 should be around the same as the HF scores in fp16, when compared to
    the HF scores in fp32.
    """
    dtype = torch.float16
    device = "cuda"
    rtol, atol = 3e-3, 3e-1
    config = GPT2Config.from_pretrained(model_name)
    if rotary:

@@ -47,21 +43,24 @@ def test_greedy_decode_gpt2(model_name, rotary, optimized, fused_ft_kernel):
    # if not rotary, we load the weight from HF but ignore the position embeddings.
    # The model would be nonsense but it doesn't matter for the test.
    model = GPTLMHeadModel.from_pretrained(
        model_name, config, strict=not rotary, device=device, dtype=dtype
    )
    model.eval()
    if not rotary:
        model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device)
        model_hf = GPT2LMHeadModelHF.from_pretrained(model_name, torch_dtype=dtype).to(
            device=device
        )
        model_ref.eval()
        model_hf.eval()
    torch.manual_seed(0)
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    input_ids = tokenizer("Hello, my dog is cute and he", return_tensors="pt").input_ids.to(
        device=device
    )
    max_length = 25
    # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
    # max_length = input_ids.shape[1] + 40

@@ -74,61 +73,102 @@ def test_greedy_decode_gpt2(model_name, rotary, optimized, fused_ft_kernel):
    scores.append(model(cur_input_ids).logits[:, -1])
    sequences.append(scores[-1].argmax(dim=-1))
    for _ in range(input_ids.shape[1] + 1, max_length):
        cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], "b -> b 1")], dim=-1)
        scores.append(model(cur_input_ids).logits[:, -1])
        sequences.append(scores[-1].argmax(dim=-1))
    sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
    scores = tuple(scores)
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        fused_ft_kernel=fused_ft_kernel,
        return_dict_in_generate=True,
        output_scores=True,
        timing=True,
    )
    print(out.sequences)
    print(tokenizer.batch_decode(out.sequences.tolist()))
    if fused_ft_kernel:
        out_cg = model.generate(
            input_ids=input_ids,
            max_length=max_length,
            fused_ft_kernel=fused_ft_kernel,
            cg=True,
            return_dict_in_generate=True,
            output_scores=True,
            timing=True,
        )
        print(out_cg.sequences)
    if not rotary:
        out_hf = model_hf.generate(
            input_ids=input_ids,
            max_length=max_length,
            return_dict_in_generate=True,
            output_scores=True,
        )
        out_ref = model_ref.generate(
            input_ids=input_ids,
            max_length=max_length,
            return_dict_in_generate=True,
            output_scores=True,
        )
        print(
            f"Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
        )
        print(
            f"Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
        )
        print(
            f"HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
        )
        print(
            f"HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
        )
        print(tokenizer.batch_decode(out_ref.sequences.tolist()))
    assert torch.all(out.sequences == sequences)
    assert torch.allclose(
        torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), rtol=rtol, atol=atol
    )
    if not rotary:
        assert torch.all(out.sequences == out_ref.sequences)
        assert torch.all(out.sequences == out_hf.sequences)
        assert (
            torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)
        ).abs().max().item() < 3 * (
            torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)
        ).abs().max().item()


@pytest.mark.parametrize(
    "model_name",
    [
        "facebook/opt-125m",
        "facebook/opt-350m",
        "facebook/opt-1.3b",
        "facebook/opt-2.7b",
        "facebook/opt-6.7b",
    ],
)
# @pytest.mark.parametrize('model_name', ["facebook/opt-125m"])
def test_greedy_decode_opt(model_name):
    """Check that our implementation of OPT generation matches the HF implementation:
    the scores in fp16 should be around the same as the HF scores in fp16, when compared to
    the HF scores in fp32.
    """
    print(f"\nMODEL: {model_name}")
    verbose = False
    dtype = torch.float16
    device = "cuda"
    rtol, atol = 3e-3, 3e-1
    fused_ft_kernel = True
    config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
    # Only prenorm supports residual_in_fp32
    config.residual_in_fp32 = getattr(config, "prenorm", True)
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True

@@ -143,8 +183,9 @@ def test_greedy_decode_opt(model_name):
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    eos_token_id = tokenizer.eos_token_id
    input_ids = tokenizer("Hello, my dog is cute and he", return_tensors="pt").input_ids.to(
        device=device
    )
    max_length = 25
    # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
    # max_length = input_ids.shape[1] + 40

@@ -157,7 +198,7 @@ def test_greedy_decode_opt(model_name):
    scores.append(model(cur_input_ids).logits[:, -1])
    sequences.append(scores[-1].argmax(dim=-1))
    for _ in range(input_ids.shape[1] + 1, max_length):
        cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], "b -> b 1")], dim=-1)
        scores.append(model(cur_input_ids).logits[:, -1])
        sequences.append(scores[-1].argmax(dim=-1))
        if eos_token_id is not None and (sequences[-1] == eos_token_id).all():

@@ -165,31 +206,41 @@ def test_greedy_decode_opt(model_name):
    sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
    scores = tuple(scores)
print('Without CUDA graph') print("Without CUDA graph")
torch.cuda.synchronize() torch.cuda.synchronize()
start = time.time() start = time.time()
out = model.generate(input_ids=input_ids, max_length=max_length, out = model.generate(
eos_token_id=eos_token_id, fused_ft_kernel=fused_ft_kernel, input_ids=input_ids,
return_dict_in_generate=True, output_scores=True, timing=True) max_length=max_length,
eos_token_id=eos_token_id,
fused_ft_kernel=fused_ft_kernel,
return_dict_in_generate=True,
output_scores=True,
timing=True,
)
torch.cuda.synchronize() torch.cuda.synchronize()
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms') print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
if verbose: if verbose:
print(out.sequences) print(out.sequences)
print(tokenizer.batch_decode(out.sequences.tolist())) print(tokenizer.batch_decode(out.sequences.tolist()))
if fused_ft_kernel: if fused_ft_kernel:
# Capture graph outside the timing loop # Capture graph outside the timing loop
batch_size, seqlen_og = input_ids.shape batch_size, seqlen_og = input_ids.shape
model._decoding_cache = update_graph_cache( model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
        print("With CUDA graph")
        torch.cuda.synchronize()
        start = time.time()
        out_cg = model.generate(
            input_ids=input_ids,
            max_length=max_length,
            fused_ft_kernel=fused_ft_kernel,
            cg=True,
            return_dict_in_generate=True,
            output_scores=True,
            timing=True,
        )
        torch.cuda.synchronize()
        print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
        if verbose:
            print(out_cg.sequences)
            print(tokenizer.batch_decode(out_cg.sequences.tolist()))
@@ -201,10 +252,11 @@ def test_greedy_decode_opt(model_name):
    print("HF fp16")
    torch.cuda.synchronize()
    start = time.time()
    out_hf = model_hf.generate(
        input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    del model_hf
    model_ref = OPTForCausalLM.from_pretrained(model_name).to(device=device)
@@ -212,23 +264,35 @@ def test_greedy_decode_opt(model_name):
    print("HF fp32")
    torch.cuda.synchronize()
    start = time.time()
    out_ref = model_ref.generate(
        input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    del model_ref
    print(tokenizer.batch_decode(out_ref.sequences.tolist()))
    if verbose:
        print(
            f"Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
        )
        print(
            f"Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
        )
        print(
            f"HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
        )
        print(
            f"HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
        )
    assert torch.all(out.sequences == sequences)
    assert torch.allclose(
        torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), rtol=rtol, atol=atol
    )
    assert torch.all(out.sequences == out_ref.sequences)
    assert torch.all(out.sequences == out_hf.sequences)
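    # Tolerance rationale: both our fp16 scores and HF's fp16 scores are measured against the
    # HF fp32 reference; our implementation is allowed at most 3x the numerical error that the
    # HF fp16 model itself shows.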
    assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (
        torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)
    ).abs().max().item()
@@ -2,34 +2,37 @@ import os
import re
import time

import pytest
import torch
from einops import rearrange
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.utils.generation import update_graph_cache
from transformers import GPT2Config


def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        fused_ft_kernel=True,
        teacher_outputs=teacher_outputs,
        return_dict_in_generate=True,
        output_scores=True,
        timing=True,
        **kwargs,
    )
    return torch.stack(out.scores, dim=1)
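# Note: teacher_outputs pins each decoded token to the provided ids instead of the argmax of the
# model's own scores, so two runs (e.g. with and without CUDA graph) score exactly the same token
# positions and their logits can be compared elementwise.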
@pytest.mark.parametrize("seqlen,maxlen", [(10, 20), (30, 150), (3000, 3400), (14000, 15000)])
# @pytest.mark.parametrize('seqlen,maxlen', [(10, 20)])
@pytest.mark.parametrize("rotary", [None, "interleaved", "block"])
# @pytest.mark.parametrize('rotary', [None])
@pytest.mark.parametrize("model_name", ["gpt2"])
def test_greedy_decode_gpt2_cg(model_name, rotary, seqlen, maxlen):
    """Check that decoding with CUDA graph is the same as decoding without CUDA graph."""
    dtype = torch.float16
    device = "cuda"
    rtol, atol = 3e-3, 3e-1
    config = GPT2Config.from_pretrained(model_name)
    config.n_positions = 16 * 1024
@@ -49,10 +52,12 @@ def test_greedy_decode_gpt2_cg(model_name, rotary, seqlen, maxlen):
    torch.manual_seed(0)
    batch_size = 1
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
    )
    teacher_outputs = torch.randint(
        0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device
    )
    logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
    logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
@@ -61,20 +66,24 @@ def test_greedy_decode_gpt2_cg(model_name, rotary, seqlen, maxlen):
    # Try increasing batch size and seqlen, then decrease them to see if it's still correct
    batch_size = 3
    maxlen += 30
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
    )
    teacher_outputs = torch.randint(
        0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device
    )
    logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
    logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
    assert torch.equal(logits, logits_cg)
    batch_size = 2
    maxlen -= 35
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
    )
    teacher_outputs = torch.randint(
        0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device
    )
    logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
    logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
    assert torch.equal(logits, logits_cg)
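    # Growing and then shrinking (batch_size, maxlen) exercises the decoding-graph cache:
    # larger shapes presumably force a re-capture, while smaller ones should reuse existing
    # graphs, and the logits must stay bit-identical to the non-CUDA-graph path either way.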
@@ -3,27 +3,23 @@
import os
import re

import pytest
import torch
from einops import rearrange
from flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_gpt2
from flash_attn.utils.distributed import all_gather_raw
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import GPT2Config, GPT2Tokenizer
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF


# @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
@pytest.mark.parametrize("world_size", [2])
# @pytest.mark.parametrize('fused_ft_kernel', [False, True])
@pytest.mark.parametrize("fused_ft_kernel", [True])
# @pytest.mark.parametrize('rotary', [False, True])
@pytest.mark.parametrize("rotary", [False])
@pytest.mark.parametrize("model_name", ["gpt2"])
def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
    """Check that our implementation of GPT2 generation matches the HF implementation:
    the scores in fp16 should be around the same as the HF scores in fp16, when compared to
@@ -45,23 +41,31 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    # Need this, otherwise when we capture the graph the process for GPU 1 would run on both
    # GPU0 and GPU1 and things would hang
    torch.cuda.set_device(device)
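    # (Pinning the current device per rank means allocations made during graph capture land on
    # that rank's GPU rather than the default cuda:0, which is presumably what would otherwise
    # make both processes contend for GPU 0 and hang.)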
    from apex.transformer import parallel_state

    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()
    # if not rotary, we load the weight from HF but ignore the position embeddings.
    # The model would be nonsense but it doesn't matter for the test.
    model = GPTLMHeadModel.from_pretrained(
        model_name,
        config,
        strict=not rotary,
        device=device,
        dtype=dtype,
        process_group=process_group,
        world_size=world_size,
        rank=rank,
    )
    model.eval()
    if not rotary:
@@ -72,8 +76,9 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
    torch.manual_seed(0)
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    input_ids = tokenizer("Hello, my dog is cute and ", return_tensors="pt").input_ids.to(
        device=device
    )
    max_length = 30
    # input_ids = torch.randint(0, 100, (1, 10), dtype=torch.long, device='cuda')
    # max_length = input_ids.shape[1] + 40
@@ -84,50 +89,87 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
    cur_input_ids = input_ids
    with torch.inference_mode():
        logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group)
        logits = rearrange(logits, "(n b) d -> b (n d)", b=input_ids.shape[0])[
            ..., : config.vocab_size
        ]
        scores.append(logits)
        sequences.append(scores[-1].argmax(dim=-1))
        for _ in range(input_ids.shape[1] + 1, max_length):
            cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], "b -> b 1")], dim=-1)
            logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group)
            logits = rearrange(logits, "(n b) d -> b (n d)", b=input_ids.shape[0])[
                ..., : config.vocab_size
            ]
            scores.append(logits)
            sequences.append(scores[-1].argmax(dim=-1))
    sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
    scores = tuple(scores)
    print(sequences)
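    # Shape sketch (illustrative numbers, not from the test): each rank produces a shard of the
    # vocab dimension, and all_gather_raw concatenates the shards along dim 0. With world_size
    # n=2, batch b=1 and a per-rank shard of d=4:
    #   x = torch.arange(8).reshape(2, 4)        # [[0, 1, 2, 3], [4, 5, 6, 7]]
    #   rearrange(x, "(n b) d -> b (n d)", b=1)  # [[0, 1, 2, 3, 4, 5, 6, 7]]
    # so the shards end up side by side along the vocab axis, and [..., : config.vocab_size]
    # trims whatever padding the tensor-parallel model added to round up the vocabulary size.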
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        tensor_parallel=world_size,
        vocab_size=config.vocab_size,
        fused_ft_kernel=fused_ft_kernel,
        return_dict_in_generate=True,
        output_scores=True,
        timing=True,
    )
    print(out.sequences)
    if fused_ft_kernel:
        out_cg = model.generate(
            input_ids=input_ids,
            max_length=max_length,
            tensor_parallel=world_size,
            vocab_size=config.vocab_size,
            fused_ft_kernel=fused_ft_kernel,
            cg=True,
            return_dict_in_generate=True,
            output_scores=True,
            timing=True,
        )
        print(out_cg.sequences)
    if not rotary:
        out_hf = model_hf.generate(
            input_ids=input_ids,
            max_length=max_length,
            return_dict_in_generate=True,
            output_scores=True,
        )
        out_ref = model_ref.generate(
            input_ids=input_ids,
            max_length=max_length,
            return_dict_in_generate=True,
            output_scores=True,
        )

        print(
            f"Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
        )
        print(
            f"Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
        )
        print(
            f"HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
        )
        print(
            f"HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
        )
    assert torch.all(out.sequences == sequences)
    assert torch.allclose(
        torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), rtol=rtol, atol=atol
    )
    if not rotary:
        assert torch.all(out.sequences == out_ref.sequences)
        assert torch.all(out.sequences == out_hf.sequences)
        assert (
            torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)
        ).abs().max().item() < 3 * (
            torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)
        ).abs().max().item()
    parallel_state.destroy_model_parallel()
@@ -2,37 +2,37 @@
import time

import pytest
import torch
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gpt_neox import gpt_neox_config_to_gpt2_config, remap_state_dict_hf_gpt_neox
from flash_attn.utils.generation import update_graph_cache
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import AutoTokenizer, GPTNeoXConfig
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM


@pytest.mark.parametrize("model_name", ["EleutherAI/gpt-neox-20b"])
def test_gpt_neox_state_dict(model_name):
    config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(model_name))
    pretrained_state_dict = remap_state_dict_hf_gpt_neox(
        state_dict_from_pretrained(model_name), config
    )
    model = GPTLMHeadModel(config, device="meta")  # Without device='meta' init is very slow
    state_dict = model.state_dict()
    assert state_dict.keys() == pretrained_state_dict.keys()
    for k in state_dict.keys():
        assert state_dict[k].shape == pretrained_state_dict[k].shape


@pytest.mark.parametrize("model_name", ["EleutherAI/gpt-neox-20b"])
def test_gpt_neox_optimized(model_name):
    """Check that our implementation of GPT-NeoX (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    device = "cuda"
    config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(model_name))
    config.use_flash_attn = True
    config.fused_bias_fc = True
@@ -47,8 +47,9 @@ def test_gpt_neox_optimized(model_name):
    batch_size = 2
    max_seqlen = 256
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
    )
    with torch.no_grad():
        out = model.transformer(input_ids)
        logits = model(input_ids).logits
@@ -56,31 +57,36 @@ def test_gpt_neox_optimized(model_name):
    # Need at least 2 GPUs, otherwise we'll OOM
    # Without device_map, the model is loaded on the CPU, which is very slow
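    # (gpt-neox-20b in fp32 is roughly 20e9 * 4 bytes ≈ 80 GB of weights, so device_map="auto"
    # lets the reference model be sharded across the available GPUs rather than kept on one
    # device.)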
    model_ref = GPTNeoXForCausalLM.from_pretrained(model_name, device_map="auto")
    model_ref.eval()
    with torch.no_grad():
        out_ref = model_ref.gpt_neox(input_ids).last_hidden_state.to(device=device)
        logits_ref = model_ref(input_ids).logits.to(device=device)
    del model_ref
    model_hf = GPTNeoXForCausalLM.from_pretrained(
        model_name, torch_dtype=dtype, device_map={"": device}
    )
    model_hf.eval()
    with torch.no_grad():
        out_hf = model_hf.gpt_neox(input_ids).last_hidden_state
        logits_hf = model_hf(input_ids).logits
    del model_hf
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
    assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()
    assert (out - out_ref).abs().mean().item() < 2 * (out_hf - out_ref).abs().mean().item()
    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
    print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
    assert (logits - logits_ref).abs().max().item() < 2 * (
        logits_hf - logits_ref
    ).abs().max().item()
    assert (logits - logits_ref).abs().mean().item() < 2 * (
        logits_hf - logits_ref
    ).abs().mean().item()
@@ -3,33 +3,29 @@
import math

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from apex.transformer import parallel_state
from einops import rearrange
from flash_attn.losses.cross_entropy import CrossEntropyLoss
from flash_attn.models.gpt import GPTLMHeadModel, shard_state_dict_tp
from flash_attn.utils.distributed import allreduce_sequence_parallel_grad
from transformers import GPT2Config

is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8


@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize("has_pos_emb", [True, False])
# @pytest.mark.parametrize('has_pos_emb', [True])
@pytest.mark.parametrize("dim", [1024])
def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
    head_dim = 64
    assert dim % head_dim == 0
@@ -40,8 +36,8 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
    num_layers = 2
    rtol, atol = (3e-3, 1e-1) if dtype == torch.bfloat16 else (3e-3, 1e-2)
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
@@ -57,15 +53,25 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
    # as rank 0 will have an extra bias that changes the RNG.
    g = torch.randn(batch_size * seqlen, device=device)
    config = GPT2Config(
        n_embd=dim,
        n_head=num_heads,
        n_layer=num_layers,
        n_positions=seqlen if has_pos_emb else 0,
        vocab_size=50257,
        resid_pdrop=0.0,
        embd_pdrop=0.0,
        attn_pdrop=0.0,
        scale_attn_by_inverse_layer_idx=True,
        use_flash_attn=True,
        fused_mlp=True,
        fused_bias_fc=True,
        fused_dropout_add_ln=True,
        residual_in_fp32=True,
        rotary_emb_fraction=0.0 if has_pos_emb else 0.5,
        pad_vocab_size_multiple=8 * world_size,
        sequence_parallel=sequence_parallel,
    )
    config.vocab_size = math.ceil(config.vocab_size / (8 * world_size)) * (8 * world_size)
    model_pt = GPTLMHeadModel(config, device=device)
@@ -73,6 +79,7 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
        if isinstance(module, nn.LayerNorm):
            nn.init.normal_(module.weight)
            nn.init.normal_(module.bias)

    model_pt.apply(init_layer_norm)
    model = GPTLMHeadModel(config, process_group=process_group, device=device)
@@ -82,15 +89,17 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
    torch.distributed.all_gather_into_tensor(
        sharded_nparams_all, torch.tensor([sharded_nparams], device=device), group=process_group
    )
    shared_nparams = sum(
        p.numel() for p in model.parameters() if getattr(p, "_shared_params", False)
    )
    shared_nparams_all = torch.empty(world_size, dtype=torch.long, device=device)
    torch.distributed.all_gather_into_tensor(
        shared_nparams_all, torch.tensor([shared_nparams], device=device), group=process_group
    )
    assert torch.all(shared_nparams_all == shared_nparams)
    assert total_nparams == (
        (sharded_nparams_all - shared_nparams_all).sum().item() + shared_nparams
    )
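    # Accounting check: tensor-parallel ("sharded") parameters are split across ranks while
    # shared parameters are replicated, so the full model size should equal the sum of each
    # rank's shard minus its replicated part, plus the replicated part counted once.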
    # vocab_size has been rounded up here
    partition_vocab_size = config.vocab_size // world_size
@@ -100,18 +109,20 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
    model.load_state_dict(shard_state_dict_tp(model_pt.state_dict(), config, world_size, rank))
    model.tie_weights()
    with torch.autocast(device_type="cuda", dtype=dtype):
        out = model(input_ids[:, :-1]).logits
        if not sequence_parallel:
            out = rearrange(out, "b s d -> (b s) d")
        out_pt = rearrange(model_pt(input_ids[:, :-1]).logits, "b s d -> (b s) d")
    partition_batch_dim = batch_size * seqlen // world_size
    assert torch.allclose(
        out,
        out_pt[:, rank * partition_vocab_size : (rank + 1) * partition_vocab_size],
        rtol=rtol,
        atol=atol,
    )
    loss_fn = CrossEntropyLoss(inplace_backward=True, reduction="none", process_group=process_group)
    loss_fn_pt = CrossEntropyLoss(inplace_backward=True, reduction="none")
    loss = loss_fn(out, input_ids[:, 1:].flatten())
    loss_pt = loss_fn_pt(out_pt, input_ids[:, 1:].flatten())
    assert torch.allclose(loss, loss_pt, rtol=rtol, atol=atol)
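    # With process_group set, this CrossEntropyLoss is expected to treat `out` as one rank's
    # vocabulary shard and communicate across the group to form the full softmax, so the
    # sharded loss should match the single-GPU loss_pt within tolerance.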
@@ -121,73 +132,105 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
    allreduce_sequence_parallel_grad(model, process_group)
    parallel_state.destroy_model_parallel()
    grad_dict = shard_state_dict_tp(
        {k: v.grad for k, v in model_pt.named_parameters()}, config, world_size, rank
    )
    assert torch.allclose(
        model.transformer.embeddings.word_embeddings.weight.grad,
        grad_dict["transformer.embeddings.word_embeddings.weight"],
        rtol=rtol,
        atol=atol * 5,
    )
    if has_pos_emb:
        assert torch.allclose(
            model.transformer.embeddings.position_embeddings.weight.grad,
            grad_dict["transformer.embeddings.position_embeddings.weight"],
            rtol=rtol,
            atol=atol,
        )
    assert torch.allclose(
        model.transformer.ln_f.weight.grad,
        grad_dict["transformer.ln_f.weight"],
        rtol=rtol,
        atol=atol,
    )
    assert torch.allclose(
        model.transformer.ln_f.bias.grad, grad_dict["transformer.ln_f.bias"], rtol=rtol, atol=atol
    )
    for i in range(num_layers):
        assert torch.allclose(
            model.transformer.layers[i].mixer.Wqkv.weight.grad,
            grad_dict[f"transformer.layers.{i}.mixer.Wqkv.weight"],
            rtol=rtol,
            atol=atol * 10,
        )
        assert torch.allclose(
            model.transformer.layers[i].mixer.Wqkv.bias.grad,
            grad_dict[f"transformer.layers.{i}.mixer.Wqkv.bias"],
            rtol=rtol,
            atol=atol * 10,
        )
        assert torch.allclose(
            model.transformer.layers[i].mixer.out_proj.weight.grad,
            grad_dict[f"transformer.layers.{i}.mixer.out_proj.weight"],
            rtol=rtol,
            atol=atol * 10,
        )
        if rank == 0:
            assert torch.allclose(
                model.transformer.layers[i].mixer.out_proj.bias.grad,
                grad_dict[f"transformer.layers.{i}.mixer.out_proj.bias"],
                rtol=rtol,
                atol=atol * 5,
            )
        assert torch.allclose(
            model.transformer.layers[i].mlp.fc1.weight.grad,
            grad_dict[f"transformer.layers.{i}.mlp.fc1.weight"],
            rtol=rtol,
            atol=atol * 10,
        )
        assert torch.allclose(
            model.transformer.layers[i].mlp.fc1.bias.grad,
            grad_dict[f"transformer.layers.{i}.mlp.fc1.bias"],
            rtol=rtol,
            atol=atol * 10,
        )
        assert torch.allclose(
            model.transformer.layers[i].mlp.fc2.weight.grad,
            grad_dict[f"transformer.layers.{i}.mlp.fc2.weight"],
            rtol=rtol,
            atol=atol * 10,
        )
        if rank == 0:
            assert torch.allclose(
                model.transformer.layers[i].mlp.fc2.bias.grad,
                grad_dict[f"transformer.layers.{i}.mlp.fc2.bias"],
                rtol=rtol,
                atol=atol * 5,
            )

        assert torch.allclose(
            model.transformer.layers[i].norm1.weight.grad,
            grad_dict[f"transformer.layers.{i}.norm1.weight"],
            rtol=rtol,
            atol=atol,
        )
        assert torch.allclose(
            model.transformer.layers[i].norm1.bias.grad,
            grad_dict[f"transformer.layers.{i}.norm1.bias"],
            rtol=rtol,
            atol=atol,
        )
        assert torch.allclose(
            model.transformer.layers[i].norm2.weight.grad,
            grad_dict[f"transformer.layers.{i}.norm2.weight"],
            rtol=rtol,
            atol=atol,
        )
        assert torch.allclose(
            model.transformer.layers[i].norm2.bias.grad,
            grad_dict[f"transformer.layers.{i}.norm2.bias"],
            rtol=rtol,
            atol=atol,
        )
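        # (The looser atol on the Wqkv / out_proj / MLP weight grads above, compared to the
        # LayerNorm grads, presumably allows for the different reduction order of the sharded
        # matmuls across ranks.)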
@@ -2,37 +2,35 @@
import time

import pytest
import torch
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gptj import gptj_config_to_gpt2_config, remap_state_dict_hf_gptj
from flash_attn.utils.generation import update_graph_cache
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import AutoTokenizer, GPTJConfig
from transformers.models.gptj.modeling_gptj import GPTJForCausalLM


@pytest.mark.parametrize("model_name", ["EleutherAI/gpt-j-6B"])
def test_gptj_state_dict(model_name):
    config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
    pretrained_state_dict = remap_state_dict_hf_gptj(state_dict_from_pretrained(model_name), config)
    model = GPTLMHeadModel(config, device="meta")  # Without device='meta' init is very slow
    state_dict = model.state_dict()
    assert state_dict.keys() == pretrained_state_dict.keys()
    for k in state_dict.keys():
        assert state_dict[k].shape == pretrained_state_dict[k].shape


@pytest.mark.parametrize("model_name", ["EleutherAI/gpt-j-6B"])
def test_gptj_optimized(model_name):
    """Check that our implementation of GPT-J (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    device = "cuda"
    config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
    config.use_flash_attn = True  # FlashAttention-2 supports headdim 256
    config.fused_bias_fc = True
@@ -46,8 +44,9 @@ def test_gptj_optimized(model_name):
    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
    )
    with torch.no_grad():
        out = model.transformer(input_ids)
        logits = model(input_ids).logits
@@ -61,34 +60,37 @@ def test_gptj_optimized(model_name):
        logits_ref = model_ref(input_ids).logits
    del model_ref
    model_hf = GPTJForCausalLM.from_pretrained(
        model_name, torch_dtype=dtype, device_map={"": device}
    )
    model_hf.eval()
    out_hf = model_hf.transformer(input_ids).last_hidden_state
    logits_hf = model_hf(input_ids).logits
    del model_hf
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
    print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
    assert (logits - logits_ref).abs().max().item() < 3 * (
        logits_hf - logits_ref
    ).abs().max().item()


@pytest.mark.parametrize("model_name", ["EleutherAI/gpt-j-6B"])
def test_gptj_generation(model_name):
"""Check that our implementation of GPT-J (with all optimizations enabled) matches the """Check that our implementation of GPT-J (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32. forward pass in fp16, when compared to the HF forward pass in fp32.
""" """
    dtype = torch.float16
    device = "cuda"
    config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
    config.use_flash_attn = True  # FlashAttention-2 supports headdim 256
    config.fused_bias_fc = True
@@ -104,56 +106,71 @@ def test_gptj_generation(model_name):
    batch_size = 1
    seqlen = 100
    max_length = 150
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
    )
    model_hf = GPTJForCausalLM.from_pretrained(
        model_name, torch_dtype=dtype, device_map={"": device}
    )
    model_hf.eval()
    print("HF fp16")
    torch.cuda.synchronize()
    start = time.time()
    out_hf = model_hf.generate(
        input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    del model_hf
    model_ref = GPTJForCausalLM.from_pretrained(model_name, device_map={"": device})
    model_ref.eval()
    with torch.no_grad():
        logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]
    del model_ref
    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model.eval()
    print("Without CUDA graph")
    torch.cuda.synchronize()
    start = time.time()
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        eos_token_id=eos_token_id,
        fused_ft_kernel=True,
        # eos_token_id=eos_token_id, fused_ft_kernel=False,
        return_dict_in_generate=True,
        output_scores=True,
        timing=True,
        teacher_outputs=out_hf.sequences,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    # Capture graph outside the timing loop
    batch_size, seqlen_og = input_ids.shape
    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
    print("With CUDA graph")
    torch.cuda.synchronize()
    start = time.time()
    out_cg = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        fused_ft_kernel=True,
        cg=True,
        return_dict_in_generate=True,
        output_scores=True,
        timing=True,
        teacher_outputs=out_hf.sequences,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    with torch.no_grad():
        logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1]
    logits_hf = torch.stack(out_hf.scores, dim=1)
    logits = torch.stack(out.scores, dim=1)
    logits_cg = torch.stack(out_cg.scores, dim=1)
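    # Both runs above are teacher-forced with out_hf.sequences, so they score the same token at
    # every position; the CUDA-graph path replays the same kernels and is expected to reproduce
    # the non-graph scores exactly (checked with torch.equal below), while both only need to stay
    # within the usual fp16 error bound of the fp32 reference.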
@@ -163,8 +180,8 @@ def test_gptj_generation(model_name):
    hf_error = (logits_hf - logits_ref).abs().max().item()
    assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error
    print(f"HF fp16 logits max diff: {hf_error}")
    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
    assert (logits - logits_ref).abs().max().item() < 2 * hf_error
    print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}")
    assert torch.equal(logits_cg, logits)
@@ -11,26 +11,25 @@ from pathlib import Path
current_dir = Path(__file__).parent.absolute()

import shutil

import pytest
import torch
from einops import rearrange
from flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp, shard_state_dict_tp
from flash_attn.models.llama import (
    config_from_checkpoint,
    inv_remap_state_dict_hf_llama,
    llama_config_to_gpt2_config,
    remap_state_dict_hf_llama,
    remap_state_dict_meta_llama,
    state_dicts_from_checkpoint,
)
from flash_attn.utils.distributed import all_gather_raw
from flash_attn.utils.generation import update_graph_cache
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import LlamaConfig, LlamaTokenizer
from transformers.models.llama.modeling_llama import LlamaForCausalLM


def _pretrained_state_dict_from_checkpoint(checkpoint_path, model_name, config, checkpoint_format):
...
import re

import pytest
import torch
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.opt import opt_config_to_gpt2_config, remap_state_dict_hf_opt
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import OPTConfig
from transformers.models.opt.modeling_opt import OPTForCausalLM


@pytest.mark.parametrize(
    "model_name", ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b"]
)
# @pytest.mark.parametrize('model_name', ["facebook/opt-350m"])
def test_opt_state_dict(model_name):
    config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
...@@ -23,7 +23,9 @@ def test_opt_state_dict(model_name):
        assert state_dict[k].shape == pretrained_state_dict[k].shape


@pytest.mark.parametrize(
    "model_name", ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b"]
)
# @pytest.mark.parametrize('model_name', ["facebook/opt-350m"])
def test_opt_optimized(model_name):
    """Check that our implementation of OPT (without all optimizations enabled) matches the
...@@ -31,14 +33,14 @@ def test_opt_optimized(model_name):
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    device = "cuda"
    config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = True
    # Only prenorm supports residual_in_fp32
    config.residual_in_fp32 = getattr(config, "prenorm", True)
    config.pad_vocab_size_multiple = 8
    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
...@@ -53,26 +55,29 @@ def test_opt_optimized(model_name):
    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
    )
    if model_name != "facebook/opt-350m":  # The OPT-350m projects the embeddings to dimension 512
        out = model.transformer(input_ids)
        out_hf = model_hf.model(input_ids).last_hidden_state
        out_ref = model_ref.model(input_ids).last_hidden_state
        print(f"Output max diff: {(out - out_ref).abs().max().item()}")
        print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
        print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
        print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
        assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
    logits = model(input_ids).logits
    logits_hf = model_hf(input_ids).logits
    logits_ref = model_ref(input_ids).logits
    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
    print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
    assert (logits - logits_ref).abs().max().item() < 3 * (
        logits_hf - logits_ref
    ).abs().max().item()
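Both the OPT check above and the ViT check below accept the optimized fp16 forward pass only if its error against an fp32 reference stays within a small factor of the error of the baseline fp16 implementation (HF or timm). A minimal sketch of that acceptance rule as a standalone helper; the function name assert_error_within_factor is hypothetical and not part of the repository:

import torch


def assert_error_within_factor(out, out_baseline, out_ref, factor):
    """Accept `out` if it is at most `factor` times further from the fp32
    reference `out_ref` than the baseline fp16 implementation `out_baseline`."""
    err = (out - out_ref).abs().max().item()
    err_baseline = (out_baseline - out_ref).abs().max().item()
    print(f"max diff: {err}, baseline max diff: {err_baseline}")
    assert err < factor * err_baseline


# Mirrors the assertion above, e.g.:
# assert_error_within_factor(logits, logits_hf, logits_ref, factor=3)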
import re

import pytest
import torch
from flash_attn.models.vit import vit_base_patch16_224 as flash_vit_base_patch16_224
from timm.models.vision_transformer import vit_base_patch16_224


@pytest.mark.parametrize("fused_mlp", [False, True])
# @pytest.mark.parametrize('fused_mlp', [False])
@pytest.mark.parametrize("optimized", [False, True])
# @pytest.mark.parametrize('optimized', [True])
def test_vit(optimized, fused_mlp):
    """Check that our implementation of ViT matches the timm's implementation:
...@@ -18,12 +16,12 @@ def test_vit(optimized, fused_mlp):
    timm' forward pass in fp16, when compared to timm's forward pass in fp32.
    """
    dtype = torch.float16
    device = "cuda"
    kwargs = {}
    if optimized:
        kwargs = dict(use_flash_attn=True, fused_bias_fc=True, fused_dropout_add_ln=True)
    kwargs["fused_mlp"] = fused_mlp
    model = flash_vit_base_patch16_224(**kwargs).to(device=device, dtype=dtype)
    model_ref = vit_base_patch16_224(pretrained=True).to(device=device)
...@@ -42,9 +40,9 @@ def test_vit(optimized, fused_mlp):
    out_timm = model_timm(x)
    out_ref = model_ref(x.float())
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"timm fp16 max diff: {(out_timm - out_ref).abs().max().item()}")
    print(f"timm fp16 mean diff: {(out_timm - out_ref).abs().mean().item()}")
    rtol = 2 if not fused_mlp else 8
    assert (out - out_ref).abs().max().item() < rtol * (out_timm - out_ref).abs().max().item()
...@@ -4,31 +4,27 @@
import math
from functools import partial

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from apex.transformer import parallel_state, tensor_parallel
from einops import rearrange
from flash_attn.modules.block import Block
from flash_attn.modules.mha import MHA, ParallelMHA
from flash_attn.modules.mlp import FusedMLP, ParallelFusedMLP
from flash_attn.utils.distributed import allreduce_sequence_parallel_grad

is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8


@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [True])
@pytest.mark.parametrize("dim", [1024])
def test_block_parallel(dim, sequence_parallel, world_size, dtype):
    head_dim = 64
    assert dim % head_dim == 0
...@@ -36,8 +32,8 @@ def test_block_parallel(dim, sequence_parallel, world_size, dtype):
assert num_heads % world_size == 0 assert num_heads % world_size == 0
rtol, atol = (3e-3, 5e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3) rtol, atol = (3e-3, 5e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
if not torch.distributed.is_initialized(): if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend='nccl', init_method='env://') torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f'cuda:{torch.distributed.get_rank()}' device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size() assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size) parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank() rank = parallel_state.get_tensor_model_parallel_rank()
...@@ -46,22 +42,37 @@ def test_block_parallel(dim, sequence_parallel, world_size, dtype):
batch_size = 2 batch_size = 2
seqlen = 1024 seqlen = 1024
assert (batch_size * seqlen) % world_size == 0 assert (batch_size * seqlen) % world_size == 0
x_pt = torch.randn(batch_size * seqlen, dim, device=device, dtype=dtype, x_pt = torch.randn(batch_size * seqlen, dim, device=device, dtype=dtype, requires_grad=True)
requires_grad=True)
residual_pt = torch.randn(batch_size * seqlen, dim, device=device, requires_grad=True) residual_pt = torch.randn(batch_size * seqlen, dim, device=device, requires_grad=True)
# We need to generate g here so that all processes get the same gradient, # We need to generate g here so that all processes get the same gradient,
# as rank 0 will have an extra bias that changes the RNG. # as rank 0 will have an extra bias that changes the RNG.
# If we don't divide by batch_size, the gradient gets a bit too large. # If we don't divide by batch_size, the gradient gets a bit too large.
g = torch.randn_like(x_pt) / 32 g = torch.randn_like(x_pt) / 32
if sequence_parallel: if sequence_parallel:
x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_() x = (
residual = tensor_parallel.scatter_to_sequence_parallel_region(residual_pt).detach().clone().requires_grad_() tensor_parallel.scatter_to_sequence_parallel_region(x_pt)
.detach()
.clone()
.requires_grad_()
)
residual = (
tensor_parallel.scatter_to_sequence_parallel_region(residual_pt)
.detach()
.clone()
.requires_grad_()
)
else: else:
x = x_pt.detach().clone().requires_grad_() x = x_pt.detach().clone().requires_grad_()
residual = residual_pt.detach().clone().requires_grad_() residual = residual_pt.detach().clone().requires_grad_()
mixer_cls_pt = partial(MHA, num_heads=num_heads, rotary_emb_dim=int(head_dim // 2), mixer_cls_pt = partial(
use_flash_attn=True, device=device, dtype=dtype) MHA,
num_heads=num_heads,
rotary_emb_dim=int(head_dim // 2),
use_flash_attn=True,
device=device,
dtype=dtype,
)
mlp_cls_pt = partial(FusedMLP, hidden_features=4 * dim, device=device, dtype=dtype) mlp_cls_pt = partial(FusedMLP, hidden_features=4 * dim, device=device, dtype=dtype)
norm_cls = partial(nn.LayerNorm, device=device, dtype=dtype) norm_cls = partial(nn.LayerNorm, device=device, dtype=dtype)
model_pt = Block(dim, mixer_cls_pt, mlp_cls_pt, norm_cls, fused_dropout_add_ln=True) model_pt = Block(dim, mixer_cls_pt, mlp_cls_pt, norm_cls, fused_dropout_add_ln=True)
...@@ -71,40 +82,68 @@ def test_block_parallel(dim, sequence_parallel, world_size, dtype):
nn.init.normal_(model_pt.norm2.weight) nn.init.normal_(model_pt.norm2.weight)
nn.init.normal_(model_pt.norm2.bias) nn.init.normal_(model_pt.norm2.bias)
mixer_cls = partial(ParallelMHA, num_heads=num_heads, mixer_cls = partial(
process_group=parallel_state.get_tensor_model_parallel_group(), ParallelMHA,
rotary_emb_dim=int(head_dim // 2), use_flash_attn=True, num_heads=num_heads,
sequence_parallel=sequence_parallel, device=device, dtype=dtype) process_group=parallel_state.get_tensor_model_parallel_group(),
mlp_cls = partial(ParallelFusedMLP, hidden_features=4 * dim, rotary_emb_dim=int(head_dim // 2),
process_group=parallel_state.get_tensor_model_parallel_group(), use_flash_attn=True,
sequence_parallel=sequence_parallel, device=device, dtype=dtype) sequence_parallel=sequence_parallel,
model = Block(dim, mixer_cls, mlp_cls, norm_cls, fused_dropout_add_ln=True, device=device,
sequence_parallel=sequence_parallel, mark_shared_params=True) dtype=dtype,
)
mlp_cls = partial(
ParallelFusedMLP,
hidden_features=4 * dim,
process_group=parallel_state.get_tensor_model_parallel_group(),
sequence_parallel=sequence_parallel,
device=device,
dtype=dtype,
)
model = Block(
dim,
mixer_cls,
mlp_cls,
norm_cls,
fused_dropout_add_ln=True,
sequence_parallel=sequence_parallel,
mark_shared_params=True,
)
partition_dim = dim // world_size partition_dim = dim // world_size
partition_hidden_dim = 4 * dim // world_size partition_hidden_dim = 4 * dim // world_size
with torch.no_grad(): with torch.no_grad():
model.mixer.Wqkv.weight.copy_( model.mixer.Wqkv.weight.copy_(
rearrange(rearrange(model_pt.mixer.Wqkv.weight, '(three o) i -> three o i', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim], rearrange(
'three o i -> (three o) i') rearrange(model_pt.mixer.Wqkv.weight, "(three o) i -> three o i", three=3)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"three o i -> (three o) i",
)
) )
model.mixer.Wqkv.bias.copy_( model.mixer.Wqkv.bias.copy_(
rearrange(rearrange(model_pt.mixer.Wqkv.bias, '(three o) -> three o', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim], rearrange(
'three o -> (three o)') rearrange(model_pt.mixer.Wqkv.bias, "(three o) -> three o", three=3)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"three o -> (three o)",
)
) )
model.mixer.out_proj.weight.copy_( model.mixer.out_proj.weight.copy_(
model_pt.mixer.out_proj.weight[:, rank * partition_dim:(rank + 1) * partition_dim] model_pt.mixer.out_proj.weight[:, rank * partition_dim : (rank + 1) * partition_dim]
) )
if rank == 0: if rank == 0:
model.mixer.out_proj.bias.copy_(model_pt.mixer.out_proj.bias) model.mixer.out_proj.bias.copy_(model_pt.mixer.out_proj.bias)
model.mlp.fc1.weight.copy_( model.mlp.fc1.weight.copy_(
model_pt.mlp.fc1.weight[rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim] model_pt.mlp.fc1.weight[rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim]
) )
model.mlp.fc1.bias.copy_( model.mlp.fc1.bias.copy_(
model_pt.mlp.fc1.bias[rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim] model_pt.mlp.fc1.bias[rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim]
) )
model.mlp.fc2.weight.copy_( model.mlp.fc2.weight.copy_(
model_pt.mlp.fc2.weight[:, rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim] model_pt.mlp.fc2.weight[
:, rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim
]
) )
if rank == 0: if rank == 0:
model.mlp.fc2.bias.copy_(model_pt.mlp.fc2.bias) model.mlp.fc2.bias.copy_(model_pt.mlp.fc2.bias)
...@@ -113,83 +152,122 @@ def test_block_parallel(dim, sequence_parallel, world_size, dtype):
model.norm2.weight.copy_(model_pt.norm2.weight) model.norm2.weight.copy_(model_pt.norm2.weight)
model.norm2.bias.copy_(model_pt.norm2.bias) model.norm2.bias.copy_(model_pt.norm2.bias)
mixer_kwargs = {'seqlen': seqlen} mixer_kwargs = {"seqlen": seqlen}
out, out_residual = model(x, residual, mixer_kwargs=mixer_kwargs) out, out_residual = model(x, residual, mixer_kwargs=mixer_kwargs)
out_pt, out_residual_pt = model_pt(rearrange(x_pt, '(b s) d -> b s d', s=seqlen), out_pt, out_residual_pt = model_pt(
rearrange(residual_pt, '(b s) d -> b s d', s=seqlen)) rearrange(x_pt, "(b s) d -> b s d", s=seqlen),
out_pt, out_residual_pt = [rearrange(x, 'b s d -> (b s) d') for x in [out_pt, out_residual_pt]] rearrange(residual_pt, "(b s) d -> b s d", s=seqlen),
)
out_pt, out_residual_pt = [rearrange(x, "b s d -> (b s) d") for x in [out_pt, out_residual_pt]]
partition_batch_dim = batch_size * seqlen // world_size partition_batch_dim = batch_size * seqlen // world_size
assert torch.allclose( assert torch.allclose(
out, out,
out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel else out_pt, if sequence_parallel
rtol=rtol, atol=atol else out_pt,
rtol=rtol,
atol=atol,
) )
assert torch.allclose( assert torch.allclose(
out_residual, out_residual,
out_residual_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] out_residual_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel else out_residual_pt, if sequence_parallel
rtol=rtol, atol=atol else out_residual_pt,
rtol=rtol,
atol=atol,
) )
(out_pt + 2 * out_residual_pt).backward(g) (out_pt + 2 * out_residual_pt).backward(g)
(out + 2 * out_residual).backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] (out + 2 * out_residual).backward(
if sequence_parallel else g) g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else g
)
allreduce_sequence_parallel_grad(model, parallel_state.get_tensor_model_parallel_group()) allreduce_sequence_parallel_grad(model, parallel_state.get_tensor_model_parallel_group())
parallel_state.destroy_model_parallel() parallel_state.destroy_model_parallel()
assert torch.allclose( assert torch.allclose(
x.grad, x.grad,
x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel else x_pt.grad, if sequence_parallel
rtol=rtol, atol=atol / 10 # magnitude of x.grad is quite small else x_pt.grad,
rtol=rtol,
atol=atol / 10, # magnitude of x.grad is quite small
) )
assert torch.allclose( assert torch.allclose(
residual.grad, residual.grad,
residual_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] residual_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel else residual_pt.grad, if sequence_parallel
rtol=rtol, atol=atol else residual_pt.grad,
rtol=rtol,
atol=atol,
) )
# The error for d_weight and d_bias is quite a bit higher # The error for d_weight and d_bias is quite a bit higher
assert torch.allclose( assert torch.allclose(
model.mixer.Wqkv.weight.grad, model.mixer.Wqkv.weight.grad,
rearrange(rearrange(model_pt.mixer.Wqkv.weight.grad, '(three o) i -> three o i', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim], rearrange(
'three o i -> (three o) i'), rearrange(model_pt.mixer.Wqkv.weight.grad, "(three o) i -> three o i", three=3)[
rtol=rtol, atol=atol * 10 :, rank * partition_dim : (rank + 1) * partition_dim
],
"three o i -> (three o) i",
),
rtol=rtol,
atol=atol * 10,
) )
assert torch.allclose( assert torch.allclose(
model.mixer.Wqkv.bias.grad, model.mixer.Wqkv.bias.grad,
rearrange(rearrange(model_pt.mixer.Wqkv.bias.grad, '(three o) -> three o', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim], rearrange(
'three o -> (three o)'), rearrange(model_pt.mixer.Wqkv.bias.grad, "(three o) -> three o", three=3)[
rtol=rtol, atol=atol * 5 :, rank * partition_dim : (rank + 1) * partition_dim
],
"three o -> (three o)",
),
rtol=rtol,
atol=atol * 5,
) )
assert torch.allclose( assert torch.allclose(
model.mixer.out_proj.weight.grad, model.mixer.out_proj.weight.grad,
model_pt.mixer.out_proj.weight.grad[:, rank * partition_dim:(rank + 1) * partition_dim], model_pt.mixer.out_proj.weight.grad[:, rank * partition_dim : (rank + 1) * partition_dim],
rtol=rtol, atol=atol * 10 rtol=rtol,
atol=atol * 10,
) )
if rank == 0: if rank == 0:
assert torch.allclose(model.mixer.out_proj.bias.grad, model_pt.mixer.out_proj.bias.grad, rtol=rtol, atol=atol * 5) assert torch.allclose(
model.mixer.out_proj.bias.grad,
model_pt.mixer.out_proj.bias.grad,
rtol=rtol,
atol=atol * 5,
)
assert torch.allclose( assert torch.allclose(
model.mlp.fc1.weight.grad, model.mlp.fc1.weight.grad,
model_pt.mlp.fc1.weight.grad[rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim], model_pt.mlp.fc1.weight.grad[
rtol=rtol, atol=atol * 10 rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim
],
rtol=rtol,
atol=atol * 10,
) )
assert torch.allclose( assert torch.allclose(
model.mlp.fc1.bias.grad, model.mlp.fc1.bias.grad,
model_pt.mlp.fc1.bias.grad[rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim], model_pt.mlp.fc1.bias.grad[rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim],
rtol=rtol, atol=atol * 5 rtol=rtol,
atol=atol * 5,
) )
assert torch.allclose( assert torch.allclose(
model.mlp.fc2.weight.grad, model.mlp.fc2.weight.grad,
model_pt.mlp.fc2.weight.grad[:, rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim], model_pt.mlp.fc2.weight.grad[
rtol=rtol, atol=atol * 10 :, rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim
],
rtol=rtol,
atol=atol * 10,
) )
if rank == 0: if rank == 0:
assert torch.allclose(model.mlp.fc2.bias.grad, model_pt.mlp.fc2.bias.grad, assert torch.allclose(
rtol=rtol, atol=atol * 5) model.mlp.fc2.bias.grad, model_pt.mlp.fc2.bias.grad, rtol=rtol, atol=atol * 5
)
assert torch.allclose(model.norm1.weight.grad, model_pt.norm1.weight.grad, rtol=rtol, atol=atol * 5) assert torch.allclose(
model.norm1.weight.grad, model_pt.norm1.weight.grad, rtol=rtol, atol=atol * 5
)
assert torch.allclose(model.norm1.bias.grad, model_pt.norm1.bias.grad, rtol=rtol, atol=atol * 5) assert torch.allclose(model.norm1.bias.grad, model_pt.norm1.bias.grad, rtol=rtol, atol=atol * 5)
assert torch.allclose(model.norm2.weight.grad, model_pt.norm2.weight.grad, rtol=rtol, atol=atol * 5) assert torch.allclose(
model.norm2.weight.grad, model_pt.norm2.weight.grad, rtol=rtol, atol=atol * 5
)
assert torch.allclose(model.norm2.bias.grad, model_pt.norm2.bias.grad, rtol=rtol, atol=atol * 5) assert torch.allclose(model.norm2.bias.grad, model_pt.norm2.bias.grad, rtol=rtol, atol=atol * 5)
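The tensor-parallel tests in this commit (test_block_parallel above, and the embedding/MHA/MLP tests below) all compare a sharded module against a single-device reference by scattering the flattened (batch * seqlen, dim) input across ranks and checking each rank's output and gradient against the matching slice of the reference. A single-process sketch of that slicing logic, without torch.distributed or apex; the helper name reference_slice is illustrative:

import torch


def reference_slice(full, rank, world_size, sequence_parallel):
    """Return the slice of a reference tensor that a given rank should match.

    With sequence parallelism the (batch * seqlen) dimension is split evenly
    across ranks; without it every rank sees the full tensor.
    """
    if not sequence_parallel:
        return full
    partition = full.shape[0] // world_size
    return full[rank * partition : (rank + 1) * partition]


# With world_size=4, rank 2 of a 32-row reference owns rows 16..23.
full_out = torch.arange(32 * 8, dtype=torch.float32).reshape(32, 8)
local_out = full_out[16:24].clone()
assert torch.allclose(local_out, reference_slice(full_out, 2, 4, sequence_parallel=True))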
# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/modules/test_embedding_parallel.py
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from apex.transformer import parallel_state
from einops import rearrange
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings

is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8


@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize("has_pos_emb", [True, False])
# @pytest.mark.parametrize('has_pos_emb', [True])
@pytest.mark.parametrize("dim", [1024])
def test_embedding_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
    vocab_size = 50264
    seqlen = 2048
...@@ -31,8 +28,8 @@ def test_embedding_parallel(dim, has_pos_emb, sequence_parallel, world_size, dty
assert dim % world_size == 0 assert dim % world_size == 0
rtol, atol = (3e-3, 5e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3) rtol, atol = (3e-3, 5e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
if not torch.distributed.is_initialized(): if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend='nccl', init_method='env://') torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f'cuda:{torch.distributed.get_rank()}' device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size() assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size) parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank() rank = parallel_state.get_tensor_model_parallel_rank()
...@@ -44,46 +41,66 @@ def test_embedding_parallel(dim, has_pos_emb, sequence_parallel, world_size, dty
input_ids_pt = torch.randint(0, vocab_size, (batch_size, seqlen), device=device) input_ids_pt = torch.randint(0, vocab_size, (batch_size, seqlen), device=device)
input_ids = input_ids_pt.detach().clone() input_ids = input_ids_pt.detach().clone()
model_pt = GPT2Embeddings(dim, vocab_size, seqlen if has_pos_emb else 0, model_pt = GPT2Embeddings(
device=device, dtype=dtype) dim, vocab_size, seqlen if has_pos_emb else 0, device=device, dtype=dtype
model = ParallelGPT2Embeddings(dim, vocab_size, seqlen if has_pos_emb else 0, )
parallel_state.get_tensor_model_parallel_group(), model = ParallelGPT2Embeddings(
sequence_parallel=sequence_parallel, device=device, dtype=dtype) dim,
vocab_size,
seqlen if has_pos_emb else 0,
parallel_state.get_tensor_model_parallel_group(),
sequence_parallel=sequence_parallel,
device=device,
dtype=dtype,
)
partition_vocab_size = vocab_size // world_size partition_vocab_size = vocab_size // world_size
partition_dim = dim // world_size partition_dim = dim // world_size
with torch.no_grad(): with torch.no_grad():
model.word_embeddings.weight.copy_( model.word_embeddings.weight.copy_(
model_pt.word_embeddings.weight[rank * partition_vocab_size:(rank + 1) * partition_vocab_size] model_pt.word_embeddings.weight[
rank * partition_vocab_size : (rank + 1) * partition_vocab_size
]
) )
if has_pos_emb: if has_pos_emb:
model.position_embeddings.weight.copy_( model.position_embeddings.weight.copy_(
model_pt.position_embeddings.weight[:, rank * partition_dim:(rank + 1) * partition_dim] model_pt.position_embeddings.weight[
:, rank * partition_dim : (rank + 1) * partition_dim
]
) )
out = model(input_ids, combine_batch_seqlen_dim=True) out = model(input_ids, combine_batch_seqlen_dim=True)
out_pt = rearrange(model_pt(input_ids), 'b s d -> (b s) d') out_pt = rearrange(model_pt(input_ids), "b s d -> (b s) d")
partition_batch_dim = batch_size * seqlen // world_size partition_batch_dim = batch_size * seqlen // world_size
assert torch.allclose( assert torch.allclose(
out, out,
out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel else out_pt, if sequence_parallel
rtol=rtol, atol=atol else out_pt,
rtol=rtol,
atol=atol,
) )
g = torch.randn_like(out_pt) g = torch.randn_like(out_pt)
out_pt.backward(g) out_pt.backward(g)
out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] out.backward(
if sequence_parallel else g) g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else g
)
parallel_state.destroy_model_parallel() parallel_state.destroy_model_parallel()
assert torch.allclose( assert torch.allclose(
model.word_embeddings.weight.grad, model.word_embeddings.weight.grad,
model_pt.word_embeddings.weight.grad[rank * partition_vocab_size:(rank + 1) * partition_vocab_size], model_pt.word_embeddings.weight.grad[
rtol=rtol, atol=atol rank * partition_vocab_size : (rank + 1) * partition_vocab_size
],
rtol=rtol,
atol=atol,
) )
if has_pos_emb: if has_pos_emb:
assert torch.allclose( assert torch.allclose(
model.position_embeddings.weight.grad, model.position_embeddings.weight.grad,
model_pt.position_embeddings.weight.grad[:, rank * partition_dim:(rank + 1) * partition_dim], model_pt.position_embeddings.weight.grad[
rtol=rtol, atol=atol :, rank * partition_dim : (rank + 1) * partition_dim
],
rtol=rtol,
atol=atol,
) )
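test_embedding_parallel above shards the word-embedding table along the vocabulary dimension and the position embeddings along the hidden dimension, then copies the matching slice of the reference weights into each rank's module. A single-process sketch of just that slicing arithmetic, with illustrative sizes matching the test:

import torch

vocab_size, seqlen, dim, world_size, rank = 50264, 2048, 1024, 8, 3
word_emb = torch.randn(vocab_size, dim)  # full table held by the reference model
pos_emb = torch.randn(seqlen, dim)

partition_vocab_size = vocab_size // world_size
partition_dim = dim // world_size

# Word embeddings are row-sharded: each rank owns a contiguous block of token ids.
word_emb_shard = word_emb[rank * partition_vocab_size : (rank + 1) * partition_vocab_size]
# Position embeddings are column-sharded: each rank owns a slice of the hidden dim.
pos_emb_shard = pos_emb[:, rank * partition_dim : (rank + 1) * partition_dim]

assert word_emb_shard.shape == (partition_vocab_size, dim)
assert pos_emb_shard.shape == (seqlen, partition_dim)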
...@@ -3,29 +3,25 @@
import math

import pytest
import torch
import torch.nn.functional as F
from apex.transformer import parallel_state, tensor_parallel
from einops import rearrange
from flash_attn.modules.mha import MHA, ParallelMHA

is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8


@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize("head_dim", [64, 128])
# @pytest.mark.parametrize('head_dim', [64])
@pytest.mark.parametrize("embed_dim", [1024, 4096])
# @pytest.mark.parametrize('embed_dim', [1024])
def test_mha_parallel(embed_dim, head_dim, sequence_parallel, world_size, dtype):
    assert embed_dim % head_dim == 0
...@@ -33,8 +29,8 @@ def test_mha_parallel(embed_dim, head_dim, sequence_parallel, world_size, dtype)
assert num_heads % world_size == 0 assert num_heads % world_size == 0
rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3) rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
if not torch.distributed.is_initialized(): if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend='nccl', init_method='env://') torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f'cuda:{torch.distributed.get_rank()}' device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size() assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size) parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank() rank = parallel_state.get_tensor_model_parallel_rank()
...@@ -43,77 +39,122 @@ def test_mha_parallel(embed_dim, head_dim, sequence_parallel, world_size, dtype)
batch_size = 2 batch_size = 2
seqlen = 1024 seqlen = 1024
assert (batch_size * seqlen) % world_size == 0 assert (batch_size * seqlen) % world_size == 0
x_pt = torch.randn(batch_size * seqlen, embed_dim, device=device, dtype=dtype, x_pt = torch.randn(
requires_grad=True) batch_size * seqlen, embed_dim, device=device, dtype=dtype, requires_grad=True
)
# We need to generate g here so that all processes get the same gradient, # We need to generate g here so that all processes get the same gradient,
# as rank 0 will have an extra bias that changes the RNG. # as rank 0 will have an extra bias that changes the RNG.
# If we don't divide by batch_size, the gradient gets a bit too large. # If we don't divide by batch_size, the gradient gets a bit too large.
g = torch.randn_like(x_pt) / 32 g = torch.randn_like(x_pt) / 32
if sequence_parallel: if sequence_parallel:
x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_() x = (
tensor_parallel.scatter_to_sequence_parallel_region(x_pt)
.detach()
.clone()
.requires_grad_()
)
else: else:
x = x_pt.detach().clone().requires_grad_() x = x_pt.detach().clone().requires_grad_()
model_pt = MHA(embed_dim, num_heads, rotary_emb_dim=int(head_dim // 2), model_pt = MHA(
use_flash_attn=True, device=device, dtype=dtype) embed_dim,
num_heads,
rotary_emb_dim=int(head_dim // 2),
use_flash_attn=True,
device=device,
dtype=dtype,
)
partition_dim = embed_dim // world_size partition_dim = embed_dim // world_size
model = ParallelMHA(embed_dim, num_heads, parallel_state.get_tensor_model_parallel_group(), model = ParallelMHA(
rotary_emb_dim=int(head_dim // 2), use_flash_attn=True, embed_dim,
sequence_parallel=sequence_parallel, device=device, dtype=dtype) num_heads,
parallel_state.get_tensor_model_parallel_group(),
rotary_emb_dim=int(head_dim // 2),
use_flash_attn=True,
sequence_parallel=sequence_parallel,
device=device,
dtype=dtype,
)
with torch.no_grad(): with torch.no_grad():
model.Wqkv.weight.copy_( model.Wqkv.weight.copy_(
rearrange(rearrange(model_pt.Wqkv.weight, '(three o) i -> three o i', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim], rearrange(
'three o i -> (three o) i') rearrange(model_pt.Wqkv.weight, "(three o) i -> three o i", three=3)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"three o i -> (three o) i",
)
) )
model.Wqkv.bias.copy_( model.Wqkv.bias.copy_(
rearrange(rearrange(model_pt.Wqkv.bias, '(three o) -> three o', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim], rearrange(
'three o -> (three o)') rearrange(model_pt.Wqkv.bias, "(three o) -> three o", three=3)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"three o -> (three o)",
)
) )
model.out_proj.weight.copy_( model.out_proj.weight.copy_(
model_pt.out_proj.weight[:, rank * partition_dim:(rank + 1) * partition_dim] model_pt.out_proj.weight[:, rank * partition_dim : (rank + 1) * partition_dim]
) )
if rank == 0: if rank == 0:
model.out_proj.bias.copy_(model_pt.out_proj.bias) model.out_proj.bias.copy_(model_pt.out_proj.bias)
out = model(x, seqlen=seqlen) out = model(x, seqlen=seqlen)
out_pt = rearrange(model_pt(rearrange(x_pt, '(b s) d -> b s d', s=seqlen)), 'b s d -> (b s) d') out_pt = rearrange(model_pt(rearrange(x_pt, "(b s) d -> b s d", s=seqlen)), "b s d -> (b s) d")
partition_batch_dim = batch_size * seqlen // world_size partition_batch_dim = batch_size * seqlen // world_size
assert torch.allclose( assert torch.allclose(
out, out,
out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel else out_pt, if sequence_parallel
rtol=rtol, atol=atol else out_pt,
rtol=rtol,
atol=atol,
) )
out_pt.backward(g) out_pt.backward(g)
out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] out.backward(
if sequence_parallel else g) g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else g
)
parallel_state.destroy_model_parallel() parallel_state.destroy_model_parallel()
assert torch.allclose( assert torch.allclose(
x.grad, x.grad,
x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel else x_pt.grad, if sequence_parallel
rtol=rtol, atol=atol / 100 # magnitude of x.grad is quite small else x_pt.grad,
rtol=rtol,
atol=atol / 100, # magnitude of x.grad is quite small
) )
# The error for d_weight and d_bias is quite a bit higher # The error for d_weight and d_bias is quite a bit higher
assert torch.allclose( assert torch.allclose(
model.Wqkv.weight.grad, model.Wqkv.weight.grad,
rearrange(rearrange(model_pt.Wqkv.weight.grad, '(three o) i -> three o i', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim], rearrange(
'three o i -> (three o) i'), rearrange(model_pt.Wqkv.weight.grad, "(three o) i -> three o i", three=3)[
rtol=rtol, atol=atol * 10 :, rank * partition_dim : (rank + 1) * partition_dim
],
"three o i -> (three o) i",
),
rtol=rtol,
atol=atol * 10,
) )
assert torch.allclose( assert torch.allclose(
model.Wqkv.bias.grad, model.Wqkv.bias.grad,
rearrange(rearrange(model_pt.Wqkv.bias.grad, '(three o) -> three o', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim], rearrange(
'three o -> (three o)'), rearrange(model_pt.Wqkv.bias.grad, "(three o) -> three o", three=3)[
rtol=rtol, atol=atol * 5 :, rank * partition_dim : (rank + 1) * partition_dim
],
"three o -> (three o)",
),
rtol=rtol,
atol=atol * 5,
) )
assert torch.allclose( assert torch.allclose(
model.out_proj.weight.grad, model.out_proj.weight.grad,
model_pt.out_proj.weight.grad[:, rank * partition_dim:(rank + 1) * partition_dim], model_pt.out_proj.weight.grad[:, rank * partition_dim : (rank + 1) * partition_dim],
rtol=rtol, atol=atol * 10 rtol=rtol,
atol=atol * 10,
) )
if rank == 0: if rank == 0:
assert torch.allclose(model.out_proj.bias.grad, model_pt.out_proj.bias.grad, rtol=rtol, atol=atol * 5) assert torch.allclose(
model.out_proj.bias.grad, model_pt.out_proj.bias.grad, rtol=rtol, atol=atol * 5
)
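The Wqkv copies in test_mha_parallel above (and in test_block_parallel) have to take each rank's slice of the Q, K and V blocks separately and re-pack them, which is what the nested rearrange calls do. A one-process sketch of that einops pattern with illustrative sizes:

import torch
from einops import rearrange

embed_dim, world_size, rank = 1024, 4, 1
partition_dim = embed_dim // world_size

wqkv = torch.randn(3 * embed_dim, embed_dim)  # packed [Q; K; V] projection weight

# Unpack to (3, embed_dim, embed_dim), take this rank's slice of each projection's
# output rows, then pack back into the local (3 * partition_dim, embed_dim) weight.
wqkv_shard = rearrange(
    rearrange(wqkv, "(three o) i -> three o i", three=3)[
        :, rank * partition_dim : (rank + 1) * partition_dim
    ],
    "three o i -> (three o) i",
)
assert wqkv_shard.shape == (3 * partition_dim, embed_dim)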
# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/modules/test_mlp_parallel.py
import pytest
import torch
import torch.nn.functional as F
from apex.transformer import parallel_state, tensor_parallel
from einops import rearrange
from flash_attn.modules.mlp import GatedMlp, ParallelGatedMlp

is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8


@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize("activation", [F.silu, F.sigmoid])
# @pytest.mark.parametrize('activation', [F.silu])
@pytest.mark.parametrize("dim", [1024, 4096])
# @pytest.mark.parametrize('dim', [1024])
def test_mlp_parallel(dim, activation, sequence_parallel, world_size, dtype):
rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3) rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
if not torch.distributed.is_initialized(): if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend='nccl', init_method='env://') torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f'cuda:{torch.distributed.get_rank()}' device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size() assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size) parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank() rank = parallel_state.get_tensor_model_parallel_rank()
...@@ -39,34 +35,51 @@ def test_mlp_parallel(dim, activation, sequence_parallel, world_size, dtype):
batch_size = 2 batch_size = 2
seqlen = 1024 seqlen = 1024
assert (batch_size * seqlen) % world_size == 0 assert (batch_size * seqlen) % world_size == 0
x_pt = torch.randn(batch_size * seqlen, dim, device=device, dtype=dtype, x_pt = torch.randn(batch_size * seqlen, dim, device=device, dtype=dtype, requires_grad=True)
requires_grad=True)
# We need to generate g here so that all processes get the same gradient, # We need to generate g here so that all processes get the same gradient,
# as rank 0 will have an extra bias that changes the RNG. # as rank 0 will have an extra bias that changes the RNG.
# If we don't divide by batch_size, the gradient gets a bit too large. # If we don't divide by batch_size, the gradient gets a bit too large.
g = torch.randn_like(x_pt) / 32 g = torch.randn_like(x_pt) / 32
if sequence_parallel: if sequence_parallel:
x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_() x = (
tensor_parallel.scatter_to_sequence_parallel_region(x_pt)
.detach()
.clone()
.requires_grad_()
)
else: else:
x = x_pt.detach().clone().requires_grad_() x = x_pt.detach().clone().requires_grad_()
model_pt = GatedMlp(dim, activation=activation, device=device, dtype=dtype) model_pt = GatedMlp(dim, activation=activation, device=device, dtype=dtype)
partition_dim = model_pt.fc1.weight.shape[0] // 2 // world_size partition_dim = model_pt.fc1.weight.shape[0] // 2 // world_size
model = ParallelGatedMlp(dim, parallel_state.get_tensor_model_parallel_group(), model = ParallelGatedMlp(
activation=activation, dim,
sequence_parallel=sequence_parallel, device=device, dtype=dtype) parallel_state.get_tensor_model_parallel_group(),
activation=activation,
sequence_parallel=sequence_parallel,
device=device,
dtype=dtype,
)
with torch.no_grad(): with torch.no_grad():
model.fc1.weight.copy_( model.fc1.weight.copy_(
rearrange(rearrange(model_pt.fc1.weight, '(two o) i -> two o i', two=2)[:, rank * partition_dim:(rank + 1) * partition_dim], rearrange(
'two o i -> (two o) i') rearrange(model_pt.fc1.weight, "(two o) i -> two o i", two=2)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"two o i -> (two o) i",
)
) )
model.fc1.bias.copy_( model.fc1.bias.copy_(
rearrange(rearrange(model_pt.fc1.bias, '(two o) -> two o', two=2)[:, rank * partition_dim:(rank + 1) * partition_dim], rearrange(
'two o -> (two o)') rearrange(model_pt.fc1.bias, "(two o) -> two o", two=2)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"two o -> (two o)",
)
) )
model.fc2.weight.copy_( model.fc2.weight.copy_(
model_pt.fc2.weight[:, rank * partition_dim:(rank + 1) * partition_dim] model_pt.fc2.weight[:, rank * partition_dim : (rank + 1) * partition_dim]
) )
if rank == 0: if rank == 0:
model.fc2.bias.copy_(model_pt.fc2.bias) model.fc2.bias.copy_(model_pt.fc2.bias)
...@@ -76,39 +89,55 @@ def test_mlp_parallel(dim, activation, sequence_parallel, world_size, dtype):
partition_batch_dim = batch_size * seqlen // world_size partition_batch_dim = batch_size * seqlen // world_size
assert torch.allclose( assert torch.allclose(
out, out,
out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel else out_pt, if sequence_parallel
rtol=rtol, atol=atol else out_pt,
rtol=rtol,
atol=atol,
) )
out_pt.backward(g) out_pt.backward(g)
out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] out.backward(
if sequence_parallel else g) g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else g
)
parallel_state.destroy_model_parallel() parallel_state.destroy_model_parallel()
assert torch.allclose( assert torch.allclose(
x.grad, x.grad,
x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel else x_pt.grad, if sequence_parallel
rtol=rtol, atol=atol else x_pt.grad,
rtol=rtol,
atol=atol,
) )
assert torch.allclose( assert torch.allclose(
model.fc1.weight.grad, model.fc1.weight.grad,
rearrange(rearrange(model_pt.fc1.weight.grad, '(two o) i -> two o i', two=2)[:, rank * partition_dim:(rank + 1) * partition_dim], rearrange(
'two o i -> (two o) i'), rearrange(model_pt.fc1.weight.grad, "(two o) i -> two o i", two=2)[
rtol=rtol, atol=atol :, rank * partition_dim : (rank + 1) * partition_dim
],
"two o i -> (two o) i",
),
rtol=rtol,
atol=atol,
) )
assert torch.allclose( assert torch.allclose(
model.fc1.bias.grad, model.fc1.bias.grad,
rearrange(rearrange(model_pt.fc1.bias.grad, '(two o) -> two o', two=2)[:, rank * partition_dim:(rank + 1) * partition_dim], rearrange(
'two o -> (two o)'), rearrange(model_pt.fc1.bias.grad, "(two o) -> two o", two=2)[
rtol=rtol, atol=atol :, rank * partition_dim : (rank + 1) * partition_dim
],
"two o -> (two o)",
),
rtol=rtol,
atol=atol,
) )
assert torch.allclose( assert torch.allclose(
model.fc2.weight.grad, model.fc2.weight.grad,
model_pt.fc2.weight.grad[:, rank * partition_dim:(rank + 1) * partition_dim], model_pt.fc2.weight.grad[:, rank * partition_dim : (rank + 1) * partition_dim],
rtol=rtol, atol=atol rtol=rtol,
atol=atol,
) )
if rank == 0: if rank == 0:
assert torch.allclose(model.fc2.bias.grad, model_pt.fc2.bias.grad, rtol=rtol, atol=atol) assert torch.allclose(model.fc2.bias.grad, model_pt.fc2.bias.grad, rtol=rtol, atol=atol)
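test_mlp_parallel applies the same re-packing trick to the gated MLP, whose fc1 stacks the two halves of the gated projection along the output dimension, so the split uses a factor of two rather than three. A minimal sketch with illustrative sizes:

import torch
from einops import rearrange

dim, world_size, rank = 1024, 2, 1
fc1_weight = torch.randn(2 * 4 * dim, dim)  # two gated halves stacked along the output dim
partition_dim = fc1_weight.shape[0] // 2 // world_size  # per-rank slice of each half

fc1_shard = rearrange(
    rearrange(fc1_weight, "(two o) i -> two o i", two=2)[
        :, rank * partition_dim : (rank + 1) * partition_dim
    ],
    "two o i -> (two o) i",
)
assert fc1_shard.shape == (2 * partition_dim, dim)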