Commit 0e8c46ae authored by Tri Dao

Run isort and black on test files

parent 7fcd3e6a
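The commit contains only the reformatted test files. For reference, here is a minimal sketch (hypothetical helper and path, not part of the commit, assuming black and isort are installed with the repository's configured settings) of the kind of formatting pass being applied:

# Hypothetical helper, not part of this commit: run isort, then black, over the test files.
import subprocess


def format_tests(path: str = "tests/") -> None:
    # Sort imports first so black only has to settle quoting and line wrapping afterwards.
    subprocess.run(["isort", path], check=True)
    subprocess.run(["black", path], check=True)


if __name__ == "__main__":
    format_tests()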
@@ -2,26 +2,24 @@
import math
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange
from flash_attn.layers.rotary import RotaryEmbedding, apply_rotary_emb_func, apply_rotary_emb_qkv_
from transformers.models.gpt_neox.modeling_gpt_neox import RotaryEmbedding as RotaryEmbeddingNeoX
from transformers.models.gpt_neox.modeling_gpt_neox import (
apply_rotary_pos_emb as apply_rotary_pos_emb_neox,
)
from transformers.models.gptj.modeling_gptj import apply_rotary_pos_emb as apply_rotary_pos_emb_gptj
from transformers.models.gptj.modeling_gptj import fixed_pos_embedding
# NeoX-style rotary embedding
@pytest.mark.parametrize("seqlen_offset", [0, 711])
@pytest.mark.parametrize("rotary_emb_fraction", [0.5, 1.0])
def test_rotary(rotary_emb_fraction, seqlen_offset):
device = "cuda"
dtype = torch.float16
rtol, atol = (1e-3, 5e-3)
# set seed
@@ -32,49 +30,70 @@ def test_rotary(rotary_emb_fraction, seqlen_offset):
nheads = 16
headdim = 128
rotary_dim = int(headdim * rotary_emb_fraction)
qkv = torch.randn(
batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype, requires_grad=True
)
qkv_og = qkv.clone().detach() # Our implementation modifies qkv inplace
rotary = RotaryEmbedding(rotary_dim, device=device)
rotary_neox = RotaryEmbeddingNeoX(rotary_dim, seqlen_total, device=device)
# Doesn't matter what tensor we pass in, rotary_neox only uses the device of the tensor
cos_neox, sin_neox = rotary_neox(qkv, seq_len=seqlen_total)
cos_neox, sin_neox = cos_neox.to(dtype=dtype), sin_neox.to(dtype=dtype)
q_pt = (
rearrange(qkv[:, :, 0, :, :rotary_dim], "b s h d -> b h s d")
.detach()
.clone()
.requires_grad_(True)
)
k_pt = (
rearrange(qkv[:, :, 1, :, :rotary_dim], "b s h d -> b h s d")
.detach()
.clone()
.requires_grad_(True)
)
q_neox, k_neox = apply_rotary_pos_emb_neox(q_pt, k_pt, cos_neox, sin_neox, offset=seqlen_offset)
out = rotary(qkv, seqlen_offset=seqlen_offset)
assert torch.allclose(
rotary._cos_cached, cos_neox[..., : rotary_dim // 2].to(dtype=dtype), rtol=rtol, atol=atol
)
assert torch.allclose(
rotary._sin_cached, sin_neox[..., : rotary_dim // 2].to(dtype=dtype), rtol=rtol, atol=atol
)
assert torch.allclose(
rearrange(q_neox, "b h s d -> b s h d"), out[:, :, 0, :, :rotary_dim], rtol=rtol, atol=atol
)
assert torch.allclose(
rearrange(k_neox, "b h s d -> b s h d"), out[:, :, 1, :, :rotary_dim], rtol=rtol, atol=atol
)
assert torch.equal(out[:, :, 0:2, :, rotary_dim:], qkv_og[:, :, 0:2, :, rotary_dim:])
assert torch.equal(out[:, :, 2], qkv_og[:, :, 2])
g = torch.randn_like(out)
g_og = g.clone().detach() # Our implementation modifies g inplace
out.backward(g)
q_neox.backward(rearrange(g_og[:, :, 0, :, :rotary_dim], "b s h d -> b h s d"))
k_neox.backward(rearrange(g_og[:, :, 1, :, :rotary_dim], "b s h d -> b h s d"))
assert torch.allclose(
rearrange(q_pt.grad, "b h s d -> b s h d"),
qkv.grad[:, :, 0, :, :rotary_dim],
rtol=rtol,
atol=atol,
)
assert torch.allclose(
rearrange(k_pt.grad, "b h s d -> b s h d"),
qkv.grad[:, :, 1, :, :rotary_dim],
rtol=rtol,
atol=atol,
)
assert torch.equal(qkv.grad[:, :, 0:2, :, rotary_dim:], g_og[:, :, 0:2, :, rotary_dim:])
assert torch.equal(qkv.grad[:, :, 2], g_og[:, :, 2])
# GPT-J-style rotary embedding
@pytest.mark.parametrize("seqlen_offset", [0, 711])
@pytest.mark.parametrize("rotary_emb_fraction", [0.5, 1.0])
def test_rotary_interleaved(rotary_emb_fraction, seqlen_offset):
device = "cuda"
dtype = torch.float16
rtol, atol = (1e-3, 5e-3)
# set seed
@@ -85,8 +104,9 @@ def test_rotary_interleaved(rotary_emb_fraction, seqlen_offset):
nheads = 16
headdim = 128
rotary_dim = int(headdim * rotary_emb_fraction)
qkv = torch.randn(
batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype, requires_grad=True
)
qkv_og = qkv.clone().detach() # Our implementation modifies qkv inplace
rotary = RotaryEmbedding(rotary_dim, interleaved=True, device=device)
sincos_gptj = fixed_pos_embedding(qkv[..., :rotary_dim], seq_dim=1, seq_len=seqlen_total)
import math
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange
from flash_attn.losses.cross_entropy import CrossEntropyLossApex
is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
@pytest.mark.parametrize(
"dtype", [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else [])
)
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("inplace_backward", [False, True])
# @pytest.mark.parametrize('inplace_backward', [False])
@pytest.mark.parametrize("smoothing", [0.0, 0.9])
@pytest.mark.parametrize("vocab_size", [50257])
def test_cross_entropy_loss_apex(vocab_size, smoothing, inplace_backward, dtype):
device = "cuda"
rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen = 128
x_pt = torch.randn(
batch_size * seqlen, vocab_size, device=device, dtype=dtype, requires_grad=True
)
x = x_pt.detach().clone().requires_grad_()
y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device)
y[torch.randperm(batch_size * seqlen)[:10]] = -100
@@ -3,35 +3,37 @@
import math
import pytest
import torch
import torch.nn.functional as F
from apex.transformer import parallel_state, tensor_parallel
from flash_attn.losses.cross_entropy import CrossEntropyLoss
is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
@pytest.mark.parametrize(
"dtype", [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else [])
)
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("inplace_backward", [False, True])
# @pytest.mark.parametrize('inplace_backward', [False])
@pytest.mark.parametrize("smoothing", [0.0, 0.9])
# @pytest.mark.parametrize('smoothing', [0.9])
@pytest.mark.parametrize("vocab_size", [50264])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
def test_cross_entropy_loss_parallel(vocab_size, world_size, smoothing, inplace_backward, dtype):
assert vocab_size % world_size == 0
rtol, atol = (
(1e-5, 1e-6)
if dtype == torch.float32
else ((1e-3, 1e-4) if dtype == torch.float16 else (1e-2, 3e-3))
)
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend="nccl", init_method="env://")
partition_vocab_size = vocab_size // world_size
device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
@@ -39,15 +41,24 @@ def test_cross_entropy_loss_parallel(vocab_size, world_size, smoothing, inplace_
torch.random.manual_seed(0)
batch_size = 8
seqlen = 128
x_pt = (
torch.randn(batch_size * seqlen, vocab_size, device=device, dtype=dtype) * 10
).requires_grad_()
x = (
tensor_parallel.scatter_to_tensor_model_parallel_region(x_pt)
.detach()
.clone()
.requires_grad_()
)
y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device)
y[torch.randperm(batch_size * seqlen)[:10]] = -100
model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing, reduction="none")
model = CrossEntropyLoss(
label_smoothing=smoothing,
reduction="none",
inplace_backward=inplace_backward,
process_group=parallel_state.get_tensor_model_parallel_group(),
)
out = model(x, y)
out_pt = model_pt(x_pt.float(), y)
assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)
@@ -55,6 +66,11 @@ def test_cross_entropy_loss_parallel(vocab_size, world_size, smoothing, inplace_
g = torch.randn_like(out)
out_pt.backward(g)
out.backward(g)
assert torch.allclose(
x.grad,
x_pt.grad[:, (rank * partition_vocab_size) : (rank + 1) * partition_vocab_size],
rtol=rtol,
atol=atol,
)
parallel_state.destroy_model_parallel()
import re
from collections import OrderedDict
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange
from flash_attn.models.bert import BertForPreTraining, BertModel, remap_state_dict
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertForPreTraining as BertForPreTrainingHF
from transformers.models.bert.modeling_bert import BertModel as BertModelHF
@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_state_dict(model_name):
config = BertConfig.from_pretrained(model_name)
@@ -30,12 +26,15 @@ def test_bert_state_dict(model_name):
def get_hf_models(model_name, config, dtype):
pretrained_state_dict = state_dict_from_pretrained(model_name)
def key_mapping_ln_gamma_beta(key):
key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key)
return key
pretrained_state_dict = OrderedDict(
(key_mapping_ln_gamma_beta(k), v) for k, v in pretrained_state_dict.items()
)
model_hf = BertForPreTrainingHF(config)
# Missing key(s) in state_dict: "bert.embeddings.position_ids", "cls.predictions.decoder.bias"
# position_ids is a buffer, and predictions.decoder.bias is tied to predictions.bias.
@@ -44,7 +43,7 @@ def get_hf_models(model_name, config, dtype):
return model_hf
@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_non_optimized(model_name):
"""Check that our implementation of BERT (without any optimizations enabled) matches the
@@ -67,10 +66,11 @@ def test_bert_non_optimized(model_name):
torch.manual_seed(0)
batch_size = 4
max_seqlen = 512
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
attention_mask = torch.arange(max_seqlen, device="cuda")[None, :] < seqlens[:, None]
input_ids = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
)
out = model.bert(input_ids, attention_mask=attention_mask)
sequence_output, pooled_output = out.last_hidden_state, out.pooler_output
out_hf = model_hf.bert(input_ids, attention_mask=attention_mask)
@@ -78,15 +78,19 @@ def test_bert_non_optimized(model_name):
out_ref = model_ref.bert(input_ids, attention_mask=attention_mask)
sequence_output_ref, pooled_output_ref = out_ref.last_hidden_state, out_ref.pooler_output
print(f"Output max diff: {(sequence_output - sequence_output_ref).abs().max().item()}")
print(f"Output mean diff: {(sequence_output - sequence_output_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(sequence_output_hf - sequence_output_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(sequence_output_hf - sequence_output_ref).abs().mean().item()}")
assert (sequence_output - sequence_output_ref).abs().max().item() < 3 * (
sequence_output_hf - sequence_output_ref
).abs().max().item()
assert (pooled_output - pooled_output_ref).abs().max().item() < 3 * (
pooled_output_hf - pooled_output_ref
).abs().max().item()
@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_optimized(model_name):
"""Check that our implementation of BERT (with all optimizations enabled) matches the
@@ -117,10 +121,11 @@ def test_bert_optimized(model_name):
torch.manual_seed(0)
batch_size = 4
max_seqlen = 512
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
attention_mask = torch.arange(max_seqlen, device="cuda")[None, :] < seqlens[:, None]
input_ids = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
)
out = model.bert(input_ids, attention_mask=attention_mask)
sequence_output, pooled_output = out.last_hidden_state, out.pooler_output
out_hf = model_hf.bert(input_ids, attention_mask=attention_mask)
@@ -131,12 +136,24 @@ def test_bert_optimized(model_name):
sequence_output_ref, pooled_output_ref = out_ref.last_hidden_state, out_ref.pooler_output
sequence_output_ref[~attention_mask, :] = 0.0
print(
f"BertModel output max diff: {(sequence_output - sequence_output_ref).abs().max().item()}"
)
print(
f"BertModel output mean diff: {(sequence_output - sequence_output_ref).abs().mean().item()}"
)
print(
f"HF fp16 BertModel max diff: {(sequence_output_hf - sequence_output_ref).abs().max().item()}"
)
print(
f"HF fp16 BertModel mean diff: {(sequence_output_hf - sequence_output_ref).abs().mean().item()}"
)
assert (sequence_output - sequence_output_ref).abs().max().item() < 4 * (
sequence_output_hf - sequence_output_ref
).abs().max().item()
assert (pooled_output - pooled_output_ref).abs().max().item() < 4 * (
pooled_output_hf - pooled_output_ref
).abs().max().item()
out = model(input_ids, attention_mask=attention_mask)
prediction_scores, seq_relationship_scores = out.prediction_logits, out.seq_relationship_logits
@@ -144,25 +161,43 @@ def test_bert_optimized(model_name):
prediction_scores = prediction_scores.clone()
prediction_scores[~attention_mask, :] = 0.0
out_hf = model_hf(input_ids, attention_mask=attention_mask)
prediction_scores_hf, seq_relationship_scores_hf = (
out_hf.prediction_logits,
out_hf.seq_relationship_logits,
)
prediction_scores_hf[~attention_mask, :] = 0.0
out_ref = model_ref(input_ids, attention_mask=attention_mask)
prediction_scores_ref, seq_relationship_scores_ref = (
out_ref.prediction_logits,
out_ref.seq_relationship_logits,
)
prediction_scores_ref[~attention_mask, :] = 0.0
print(
f"prediction_scores max diff: {(prediction_scores - prediction_scores_ref).abs().max().item()}"
)
print(
f"prediction_scores mean diff: {(prediction_scores - prediction_scores_ref).abs().mean().item()}"
)
print(
f"HF fp16 prediction_scoresff: {(prediction_scores_hf - prediction_scores_ref).abs().max().item()}"
)
print(
f"HF fp16 prediction_scoresiff: {(prediction_scores_hf - prediction_scores_ref).abs().mean().item()}"
)
assert (prediction_scores - prediction_scores_ref).abs().max().item() < 2 * (
prediction_scores_hf - prediction_scores_ref
).abs().max().item()
assert (seq_relationship_scores - seq_relationship_scores_ref).abs().max().item() < 2 * (
seq_relationship_scores_hf - seq_relationship_scores_ref
).abs().max().item()
@pytest.mark.parametrize("last_layer_subset", [False, True])
# @pytest.mark.parametrize('last_layer_subset', [True])
@pytest.mark.parametrize("has_key_padding_mask", [True, False])
# @pytest.mark.parametrize('has_key_padding_mask', [True])
@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_dense_seq_output(model_name, has_key_padding_mask, last_layer_subset):
"""Check that our implementation of BERT (with all optimizations enabled) matches the
@@ -196,40 +231,70 @@ def test_bert_dense_seq_output(model_name, has_key_padding_mask, last_layer_subs
torch.manual_seed(0)
batch_size = 4
max_seqlen = 512
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
if has_key_padding_mask:
attention_mask = torch.arange(max_seqlen, device="cuda")[None, :] < seqlens[:, None]
else:
attention_mask = None
input_ids = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
)
labels = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
)
if attention_mask is not None:
labels[~attention_mask] = 0
labels[(torch.rand(batch_size, max_seqlen, device="cuda") > 0.15)] = 0
masked_tokens_mask = labels.flatten() > 0
next_sequence_label = torch.randint(0, 2, (batch_size,), device="cuda")
out = model(
input_ids,
attention_mask=attention_mask,
labels=labels,
next_sentence_label=next_sequence_label,
)
prediction_scores, seq_relationship_scores = out.prediction_logits, out.seq_relationship_logits
out_hf = model_hf(
input_ids,
attention_mask=attention_mask,
labels=labels,
next_sentence_label=next_sequence_label,
)
prediction_scores_hf, seq_relationship_scores_hf = (
out_hf.prediction_logits,
out_hf.seq_relationship_logits,
)
prediction_scores_hf = rearrange(prediction_scores_hf, "b s d -> (b s) d")[masked_tokens_mask]
out_ref = model_ref(
input_ids,
attention_mask=attention_mask,
labels=labels,
next_sentence_label=next_sequence_label,
)
prediction_scores_ref, seq_relationship_scores_ref = (
out_ref.prediction_logits,
out_ref.seq_relationship_logits,
)
prediction_scores_ref = rearrange(prediction_scores_ref, "b s d -> (b s) d")[masked_tokens_mask]
print(
f"prediction_scores max diff: {(prediction_scores - prediction_scores_ref).abs().max().item()}"
)
print(
f"prediction_scores mean diff: {(prediction_scores - prediction_scores_ref).abs().mean().item()}"
)
print(
f"HF fp16 prediction_scoresff: {(prediction_scores_hf - prediction_scores_ref).abs().max().item()}"
)
print(
f"HF fp16 prediction_scoresiff: {(prediction_scores_hf - prediction_scores_ref).abs().mean().item()}"
)
assert (prediction_scores - prediction_scores_ref).abs().max().item() < 2 * (
prediction_scores_hf - prediction_scores_ref
).abs().max().item()
assert (seq_relationship_scores - seq_relationship_scores_ref).abs().max().item() < 2 * (
seq_relationship_scores_hf - seq_relationship_scores_ref
).abs().max().item()
# The loss calculation from HF is wrong: it doesn't ignore the labels that are 0.
# assert (out.loss - out_ref.loss).abs().max().item() < 2 * (out_hf.loss - out_ref.loss).abs().max().item()
import re
import pytest
import torch
from flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_gpt2
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF
@pytest.mark.parametrize("model_name", ["gpt2", "gpt2-medium"])
# @pytest.mark.parametrize('model_name', ["gpt2"])
def test_gpt2_state_dict(model_name):
config = GPT2Config.from_pretrained(model_name)
@@ -23,7 +20,7 @@ def test_gpt2_state_dict(model_name):
assert state_dict[k].shape == pretrained_state_dict[k].shape
@pytest.mark.parametrize("model_name", ["gpt2", "gpt2-medium"])
# @pytest.mark.parametrize('model_name', ["gpt2"])
def test_gpt2_non_optimized(model_name):
"""Check that our implementation of GPT2 (without any optimizations enabled) matches the
@@ -46,31 +43,34 @@ def test_gpt2_non_optimized(model_name):
torch.manual_seed(0)
batch_size = 4
max_seqlen = 512
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
input_ids = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
)
out = model.transformer(input_ids)
out_hf = model_hf.transformer(input_ids).last_hidden_state
out_ref = model_ref.transformer(input_ids).last_hidden_state
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
logits = model(input_ids).logits
logits_hf = model_hf(input_ids).logits
logits_ref = model_ref(input_ids).logits
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
assert (logits - logits_ref).abs().max().item() < 3 * (
logits_hf - logits_ref
).abs().max().item()
@pytest.mark.parametrize("model_name", ["gpt2", "gpt2-medium"])
# @pytest.mark.parametrize('model_name', ["gpt2"])
def test_gpt2_optimized(model_name):
"""Check that our implementation of GPT2 (with all optimizations enabled) matches the
@@ -100,25 +100,28 @@ def test_gpt2_optimized(model_name):
torch.manual_seed(0)
batch_size = 4
max_seqlen = 512
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
input_ids = torch.randint(
0, vocab_size_og, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
)
out = model.transformer(input_ids)
out_hf = model_hf.transformer(input_ids).last_hidden_state
out_ref = model_ref.transformer(input_ids).last_hidden_state
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
logits = model(input_ids).logits[..., :vocab_size_og]
logits_hf = model_hf(input_ids).logits
logits_ref = model_ref(input_ids).logits
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
assert (logits - logits_ref).abs().max().item() < 3 * (
logits_hf - logits_ref
).abs().max().item()
@@ -2,36 +2,32 @@ import os
import re
import time
import pytest
import torch
from einops import rearrange
from flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_gpt2
from flash_attn.models.opt import opt_config_to_gpt2_config, remap_state_dict_hf_opt
from flash_attn.utils.generation import update_graph_cache
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import AutoTokenizer, GPT2Config, GPT2Tokenizer, OPTConfig
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF
from transformers.models.opt.modeling_opt import OPTForCausalLM
@pytest.mark.parametrize("fused_ft_kernel", [False, True])
# @pytest.mark.parametrize('fused_ft_kernel', [True])
@pytest.mark.parametrize("optimized", [False, True])
# @pytest.mark.parametrize('optimized', [False])
@pytest.mark.parametrize("rotary", [False, True])
# @pytest.mark.parametrize('rotary', [False])
@pytest.mark.parametrize("model_name", ["gpt2"])
def test_greedy_decode_gpt2(model_name, rotary, optimized, fused_ft_kernel):
"""Check that our implementation of GPT2 generation matches the HF implementation:
the scores in fp16 should be around the same as the HF scores in fp16, when compared to
the HF scores in fp32.
"""
dtype = torch.float16
device = "cuda"
rtol, atol = 3e-3, 3e-1
config = GPT2Config.from_pretrained(model_name)
if rotary:
@@ -47,21 +43,24 @@ def test_greedy_decode_gpt2(model_name, rotary, optimized, fused_ft_kernel):
# if not rotary, we load the weight from HF but ignore the position embeddings.
# The model would be nonsense but it doesn't matter for the test.
model = GPTLMHeadModel.from_pretrained(
model_name, config, strict=not rotary, device=device, dtype=dtype
)
model.eval()
if not rotary:
model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device)
model_hf = GPT2LMHeadModelHF.from_pretrained(model_name, torch_dtype=dtype).to(
device=device
)
model_ref.eval()
model_hf.eval()
torch.manual_seed(0)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
input_ids = tokenizer("Hello, my dog is cute and he",
return_tensors="pt").input_ids.to(device=device)
input_ids = tokenizer("Hello, my dog is cute and he", return_tensors="pt").input_ids.to(
device=device
)
max_length = 25
# input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
# max_length = input_ids.shape[1] + 40
@@ -74,61 +73,102 @@ def test_greedy_decode_gpt2(model_name, rotary, optimized, fused_ft_kernel):
scores.append(model(cur_input_ids).logits[:, -1])
sequences.append(scores[-1].argmax(dim=-1))
for _ in range(input_ids.shape[1] + 1, max_length):
cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], "b -> b 1")], dim=-1)
scores.append(model(cur_input_ids).logits[:, -1])
sequences.append(scores[-1].argmax(dim=-1))
sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
scores = tuple(scores)
out = model.generate(
input_ids=input_ids,
max_length=max_length,
fused_ft_kernel=fused_ft_kernel,
return_dict_in_generate=True,
output_scores=True,
timing=True,
)
print(out.sequences)
print(tokenizer.batch_decode(out.sequences.tolist()))
if fused_ft_kernel:
out_cg = model.generate(
input_ids=input_ids,
max_length=max_length,
fused_ft_kernel=fused_ft_kernel,
cg=True,
return_dict_in_generate=True,
output_scores=True,
timing=True,
)
print(out_cg.sequences)
if not rotary:
out_hf = model_hf.generate(
input_ids=input_ids,
max_length=max_length,
return_dict_in_generate=True,
output_scores=True,
)
out_ref = model_ref.generate(
input_ids=input_ids,
max_length=max_length,
return_dict_in_generate=True,
output_scores=True,
)
print(
f"Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
)
print(
f"Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
)
print(
f"HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
)
print(
f"HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
)
print(tokenizer.batch_decode(out_ref.sequences.tolist()))
assert torch.all(out.sequences == sequences)
assert torch.allclose(
torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), rtol=rtol, atol=atol
)
if not rotary:
assert torch.all(out.sequences == out_ref.sequences)
assert torch.all(out.sequences == out_hf.sequences)
assert (
torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)
).abs().max().item() < 3 * (
torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)
).abs().max().item()
@pytest.mark.parametrize('model_name', ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b", "facebook/opt-2.7b", "facebook/opt-6.7b"])
@pytest.mark.parametrize(
"model_name",
[
"facebook/opt-125m",
"facebook/opt-350m",
"facebook/opt-1.3b",
"facebook/opt-2.7b",
"facebook/opt-6.7b",
],
)
# @pytest.mark.parametrize('model_name', ["facebook/opt-125m"])
def test_greedy_decode_opt(model_name):
"""Check that our implementation of OPT generation matches the HF implementation:
the scores in fp16 should be around the same as the HF scores in fp16, when compared to
the HF scores in fp32.
"""
print(f"\nMODEL: {model_name}")
verbose = False
dtype = torch.float16
device = "cuda"
rtol, atol = 3e-3, 3e-1
fused_ft_kernel = True
config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
# Only prenorm supports residual_in_fp32
config.residual_in_fp32 = getattr(config, "prenorm", True)
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = True
@@ -143,8 +183,9 @@ def test_greedy_decode_opt(model_name):
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
eos_token_id = tokenizer.eos_token_id
input_ids = tokenizer("Hello, my dog is cute and he",
return_tensors="pt").input_ids.to(device=device)
input_ids = tokenizer("Hello, my dog is cute and he", return_tensors="pt").input_ids.to(
device=device
)
max_length = 25
# input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
# max_length = input_ids.shape[1] + 40
@@ -157,7 +198,7 @@ def test_greedy_decode_opt(model_name):
scores.append(model(cur_input_ids).logits[:, -1])
sequences.append(scores[-1].argmax(dim=-1))
for _ in range(input_ids.shape[1] + 1, max_length):
cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], "b -> b 1")], dim=-1)
scores.append(model(cur_input_ids).logits[:, -1])
sequences.append(scores[-1].argmax(dim=-1))
if eos_token_id is not None and (sequences[-1] == eos_token_id).all():
@@ -165,31 +206,41 @@ def test_greedy_decode_opt(model_name):
sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
scores = tuple(scores)
print("Without CUDA graph")
torch.cuda.synchronize()
start = time.time()
out = model.generate(
input_ids=input_ids,
max_length=max_length,
eos_token_id=eos_token_id,
fused_ft_kernel=fused_ft_kernel,
return_dict_in_generate=True,
output_scores=True,
timing=True,
)
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
if verbose:
print(out.sequences)
print(tokenizer.batch_decode(out.sequences.tolist()))
if fused_ft_kernel:
# Capture graph outside the timing loop
batch_size, seqlen_og = input_ids.shape
model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
print("With CUDA graph")
torch.cuda.synchronize()
start = time.time()
out_cg = model.generate(
input_ids=input_ids,
max_length=max_length,
fused_ft_kernel=fused_ft_kernel,
cg=True,
return_dict_in_generate=True,
output_scores=True,
timing=True,
)
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
if verbose:
print(out_cg.sequences)
print(tokenizer.batch_decode(out_cg.sequences.tolist()))
@@ -201,10 +252,11 @@ def test_greedy_decode_opt(model_name):
print("HF fp16")
torch.cuda.synchronize()
start = time.time()
out_hf = model_hf.generate(
input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
)
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
del model_hf
model_ref = OPTForCausalLM.from_pretrained(model_name).to(device=device)
@@ -212,23 +264,35 @@ def test_greedy_decode_opt(model_name):
print("HF fp32")
torch.cuda.synchronize()
start = time.time()
out_ref = model_ref.generate(
input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
)
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
del model_ref
print(tokenizer.batch_decode(out_ref.sequences.tolist()))
if verbose:
print(
f"Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
)
print(
f"Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
)
print(
f"HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
)
print(
f"HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
)
assert torch.all(out.sequences == sequences)
assert torch.allclose(
torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), rtol=rtol, atol=atol
)
assert torch.all(out.sequences == out_ref.sequences)
assert torch.all(out.sequences == out_hf.sequences)
assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (
torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)
).abs().max().item()
@@ -2,34 +2,37 @@ import os
import re
import time
import pytest
import torch
from einops import rearrange
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.utils.generation import update_graph_cache
from transformers import GPT2Config
def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
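# Greedy-decodes via model.generate with fused_ft_kernel=True (teacher_outputs and any
# CUDA-graph options pass through **kwargs) and returns the per-step scores stacked along dim 1.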
out = model.generate(
input_ids=input_ids,
max_length=max_length,
fused_ft_kernel=True,
teacher_outputs=teacher_outputs,
return_dict_in_generate=True,
output_scores=True,
timing=True,
**kwargs,
)
return torch.stack(out.scores, dim=1)
@pytest.mark.parametrize("seqlen,maxlen", [(10, 20), (30, 150), (3000, 3400), (14000, 15000)])
# @pytest.mark.parametrize('seqlen,maxlen', [(10, 20)])
@pytest.mark.parametrize("rotary", [None, "interleaved", "block"])
# @pytest.mark.parametrize('rotary', [None])
@pytest.mark.parametrize("model_name", ["gpt2"])
def test_greedy_decode_gpt2_cg(model_name, rotary, seqlen, maxlen):
"""Check that decoding with CUDA graph is the same as decoding without CUDA graph.
"""
"""Check that decoding with CUDA graph is the same as decoding without CUDA graph."""
dtype = torch.float16
device = "cuda"
rtol, atol = 3e-3, 3e-1
config = GPT2Config.from_pretrained(model_name)
config.n_positions = 16 * 1024
@@ -49,10 +52,12 @@ def test_greedy_decode_gpt2_cg(model_name, rotary, seqlen, maxlen):
torch.manual_seed(0)
batch_size = 1
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
)
teacher_outputs = torch.randint(
0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device
)
logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
@@ -61,20 +66,24 @@ def test_greedy_decode_gpt2_cg(model_name, rotary, seqlen, maxlen):
# Try increasing batch size and seqlen, then decrease them to see if it's still correct
batch_size = 3
maxlen += 30
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
)
teacher_outputs = torch.randint(
0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device
)
logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
assert torch.equal(logits, logits_cg)
batch_size = 2
maxlen -= 35
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
)
teacher_outputs = torch.randint(
0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device
)
logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
assert torch.equal(logits, logits_cg)
@@ -3,27 +3,23 @@
import os
import re
import pytest
import torch
from einops import rearrange
from flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_gpt2
from flash_attn.utils.distributed import all_gather_raw
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import GPT2Config, GPT2Tokenizer
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF
# @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
@pytest.mark.parametrize("world_size", [2])
# @pytest.mark.parametrize('fused_ft_kernel', [False, True])
@pytest.mark.parametrize("fused_ft_kernel", [True])
# @pytest.mark.parametrize('rotary', [False, True])
@pytest.mark.parametrize("rotary", [False])
@pytest.mark.parametrize("model_name", ["gpt2"])
def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
"""Check that our implementation of GPT2 generation matches the HF implementation:
the scores in fp16 should be around the same as the HF scores in fp16, when compared to
@@ -45,23 +41,31 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size()
# Need this, otherwise when we capture the graph the process for GPU 1 would run on both
# GPU0 and GPU1 and things would hang
torch.cuda.set_device(device)
from apex.transformer import parallel_state
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
process_group = parallel_state.get_tensor_model_parallel_group()
# if not rotary, we load the weight from HF but ignore the position embeddings.
# The model would be nonsense but it doesn't matter for the test.
model = GPTLMHeadModel.from_pretrained(
model_name,
config,
strict=not rotary,
device=device,
dtype=dtype,
process_group=process_group,
world_size=world_size,
rank=rank,
)
model.eval()
if not rotary:
@@ -72,8 +76,9 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
torch.manual_seed(0)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
input_ids = tokenizer("Hello, my dog is cute and ",
return_tensors="pt").input_ids.to(device=device)
input_ids = tokenizer("Hello, my dog is cute and ", return_tensors="pt").input_ids.to(
device=device
)
max_length = 30
# input_ids = torch.randint(0, 100, (1, 10), dtype=torch.long, device='cuda')
# max_length = input_ids.shape[1] + 40
@@ -84,50 +89,87 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
cur_input_ids = input_ids
with torch.inference_mode():
logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group)
logits = rearrange(logits, "(n b) d -> b (n d)", b=input_ids.shape[0])[
..., : config.vocab_size
]
scores.append(logits)
sequences.append(scores[-1].argmax(dim=-1))
for _ in range(input_ids.shape[1] + 1, max_length):
cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], "b -> b 1")], dim=-1)
logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group)
logits = rearrange(logits, "(n b) d -> b (n d)", b=input_ids.shape[0])[
..., : config.vocab_size
]
scores.append(logits)
sequences.append(scores[-1].argmax(dim=-1))
sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
scores = tuple(scores)
print(sequences)
out = model.generate(
input_ids=input_ids,
max_length=max_length,
tensor_parallel=world_size,
vocab_size=config.vocab_size,
fused_ft_kernel=fused_ft_kernel,
return_dict_in_generate=True,
output_scores=True,
timing=True,
)
print(out.sequences)
if fused_ft_kernel:
out_cg = model.generate(
input_ids=input_ids,
max_length=max_length,
tensor_parallel=world_size,
vocab_size=config.vocab_size,
fused_ft_kernel=fused_ft_kernel,
cg=True,
return_dict_in_generate=True,
output_scores=True,
timing=True,
)
print(out_cg.sequences)
if not rotary:
out_hf = model_hf.generate(
input_ids=input_ids,
max_length=max_length,
return_dict_in_generate=True,
output_scores=True,
)
out_ref = model_ref.generate(
input_ids=input_ids,
max_length=max_length,
return_dict_in_generate=True,
output_scores=True,
)
print(
f"Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
)
print(
f"Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
)
print(
f"HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
)
print(
f"HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
)
assert torch.all(out.sequences == sequences)
assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1),
rtol=rtol, atol=atol)
assert torch.allclose(
torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), rtol=rtol, atol=atol
)
if not rotary:
assert torch.all(out.sequences == out_ref.sequences)
assert torch.all(out.sequences == out_hf.sequences)
assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()
assert (
torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)
).abs().max().item() < 3 * (
torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)
).abs().max().item()
parallel_state.destroy_model_parallel()
......@@ -2,37 +2,37 @@
import time
import torch
import pytest
from transformers import GPTNeoXConfig, AutoTokenizer
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
import torch
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox, gpt_neox_config_to_gpt2_config
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.models.gpt_neox import gpt_neox_config_to_gpt2_config, remap_state_dict_hf_gpt_neox
from flash_attn.utils.generation import update_graph_cache
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import AutoTokenizer, GPTNeoXConfig
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
@pytest.mark.parametrize('model_name', ["EleutherAI/gpt-neox-20b"])
@pytest.mark.parametrize("model_name", ["EleutherAI/gpt-neox-20b"])
def test_gpt_neox_state_dict(model_name):
config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(model_name))
pretrained_state_dict = remap_state_dict_hf_gpt_neox(state_dict_from_pretrained(model_name), config)
model = GPTLMHeadModel(config, device='meta') # Without device='meta' init is very slow
pretrained_state_dict = remap_state_dict_hf_gpt_neox(
state_dict_from_pretrained(model_name), config
)
model = GPTLMHeadModel(config, device="meta") # Without device='meta' init is very slow
state_dict = model.state_dict()
assert state_dict.keys() == pretrained_state_dict.keys()
for k in state_dict.keys():
assert state_dict[k].shape == pretrained_state_dict[k].shape
@pytest.mark.parametrize('model_name', ["EleutherAI/gpt-neox-20b"])
@pytest.mark.parametrize("model_name", ["EleutherAI/gpt-neox-20b"])
def test_gpt_neox_optimized(model_name):
"""Check that our implementation of GPT-NeoX (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
dtype = torch.float16
device = 'cuda'
device = "cuda"
config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(model_name))
config.use_flash_attn = True
config.fused_bias_fc = True
......@@ -47,8 +47,9 @@ def test_gpt_neox_optimized(model_name):
batch_size = 2
max_seqlen = 256
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
device=device)
input_ids = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
)
with torch.no_grad():
out = model.transformer(input_ids)
logits = model(input_ids).logits
......@@ -56,31 +57,36 @@ def test_gpt_neox_optimized(model_name):
# Need at least 2 GPUs, otherwise we'll OOM
# Without device_map, the model is loaded on the CPU, which is very slow
model_ref = GPTNeoXForCausalLM.from_pretrained(model_name, device_map='auto')
model_ref = GPTNeoXForCausalLM.from_pretrained(model_name, device_map="auto")
model_ref.eval()
with torch.no_grad():
out_ref = model_ref.gpt_neox(input_ids).last_hidden_state.to(device=device)
logits_ref = model_ref(input_ids).logits.to(device=device)
del model_ref
model_hf = GPTNeoXForCausalLM.from_pretrained(model_name, torch_dtype=dtype,
device_map={"": device})
model_hf = GPTNeoXForCausalLM.from_pretrained(
model_name, torch_dtype=dtype, device_map={"": device}
)
model_hf.eval()
with torch.no_grad():
out_hf = model_hf.gpt_neox(input_ids).last_hidden_state
logits_hf = model_hf(input_ids).logits
del model_hf
print(f'Output max diff: {(out - out_ref).abs().max().item()}')
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()
assert (out - out_ref).abs().mean().item() < 2 * (out_hf - out_ref).abs().mean().item()
print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
assert (logits - logits_ref).abs().max().item() < 2 * (logits_hf - logits_ref).abs().max().item()
assert (logits - logits_ref).abs().mean().item() < 2 * (logits_hf - logits_ref).abs().mean().item()
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
assert (logits - logits_ref).abs().max().item() < 2 * (
logits_hf - logits_ref
).abs().max().item()
assert (logits - logits_ref).abs().mean().item() < 2 * (
logits_hf - logits_ref
).abs().mean().item()
......@@ -3,33 +3,29 @@
import math
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytest
from einops import rearrange
from transformers import GPT2Config
from apex.transformer import parallel_state
from flash_attn.models.gpt import GPTLMHeadModel, shard_state_dict_tp
from einops import rearrange
from flash_attn.losses.cross_entropy import CrossEntropyLoss
from flash_attn.models.gpt import GPTLMHeadModel, shard_state_dict_tp
from flash_attn.utils.distributed import allreduce_sequence_parallel_grad
from transformers import GPT2Config
is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
@pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('sequence_parallel', [True, False])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize('has_pos_emb', [True, False])
@pytest.mark.parametrize("has_pos_emb", [True, False])
# @pytest.mark.parametrize('has_pos_emb', [True])
@pytest.mark.parametrize('dim', [1024])
@pytest.mark.parametrize("dim", [1024])
def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
head_dim = 64
assert dim % head_dim == 0
......@@ -40,8 +36,8 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
num_layers = 2
rtol, atol = (3e-3, 1e-1) if dtype == torch.bfloat16 else (3e-3, 1e-2)
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend='nccl', init_method='env://')
device = f'cuda:{torch.distributed.get_rank()}'
torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
......@@ -57,15 +53,25 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
# as rank 0 will have an extra bias that changes the RNG.
g = torch.randn(batch_size * seqlen, device=device)
config = GPT2Config(n_embd=dim, n_head=num_heads, n_layer=num_layers,
n_positions=seqlen if has_pos_emb else 0,
vocab_size=50257, resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0,
scale_attn_by_inverse_layer_idx=True, use_flash_attn=True,
fused_mlp=True, fused_bias_fc=True, fused_dropout_add_ln=True,
residual_in_fp32=True,
rotary_emb_fraction=0.0 if has_pos_emb else 0.5,
pad_vocab_size_multiple=8 * world_size,
sequence_parallel=sequence_parallel)
config = GPT2Config(
n_embd=dim,
n_head=num_heads,
n_layer=num_layers,
n_positions=seqlen if has_pos_emb else 0,
vocab_size=50257,
resid_pdrop=0.0,
embd_pdrop=0.0,
attn_pdrop=0.0,
scale_attn_by_inverse_layer_idx=True,
use_flash_attn=True,
fused_mlp=True,
fused_bias_fc=True,
fused_dropout_add_ln=True,
residual_in_fp32=True,
rotary_emb_fraction=0.0 if has_pos_emb else 0.5,
pad_vocab_size_multiple=8 * world_size,
sequence_parallel=sequence_parallel,
)
config.vocab_size = math.ceil(config.vocab_size / (8 * world_size)) * (8 * world_size)
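# Illustration (not executed): with world_size=8 the multiple is 8 * 8 = 64, so the GPT-2
# vocab of 50257 is padded to math.ceil(50257 / 64) * 64 = 50304 before being sharded.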
model_pt = GPTLMHeadModel(config, device=device)
......@@ -73,6 +79,7 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
if isinstance(module, nn.LayerNorm):
nn.init.normal_(module.weight)
nn.init.normal_(module.bias)
model_pt.apply(init_layer_norm)
model = GPTLMHeadModel(config, process_group=process_group, device=device)
......@@ -82,15 +89,17 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
torch.distributed.all_gather_into_tensor(
sharded_nparams_all, torch.tensor([sharded_nparams], device=device), group=process_group
)
shared_nparams = sum(p.numel() for p in model.parameters()
if getattr(p, '_shared_params', False))
shared_nparams = sum(
p.numel() for p in model.parameters() if getattr(p, "_shared_params", False)
)
shared_nparams_all = torch.empty(world_size, dtype=torch.long, device=device)
torch.distributed.all_gather_into_tensor(
shared_nparams_all, torch.tensor([shared_nparams], device=device), group=process_group
)
assert torch.all(shared_nparams_all == shared_nparams)
assert total_nparams == ((sharded_nparams_all - shared_nparams_all).sum().item()
+ shared_nparams)
assert total_nparams == (
(sharded_nparams_all - shared_nparams_all).sum().item() + shared_nparams
)
# vocab_size has been rounded up here
partition_vocab_size = config.vocab_size // world_size
......@@ -100,18 +109,20 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
model.load_state_dict(shard_state_dict_tp(model_pt.state_dict(), config, world_size, rank))
model.tie_weights()
with torch.autocast(device_type='cuda', dtype=dtype):
with torch.autocast(device_type="cuda", dtype=dtype):
out = model(input_ids[:, :-1]).logits
if not sequence_parallel:
out = rearrange(out, 'b s d -> (b s) d')
out_pt = rearrange(model_pt(input_ids[:, :-1]).logits, 'b s d -> (b s) d')
out = rearrange(out, "b s d -> (b s) d")
out_pt = rearrange(model_pt(input_ids[:, :-1]).logits, "b s d -> (b s) d")
partition_batch_dim = batch_size * seqlen // world_size
assert torch.allclose(
out, out_pt[:, rank * partition_vocab_size:(rank + 1) * partition_vocab_size],
rtol=rtol, atol=atol
out,
out_pt[:, rank * partition_vocab_size : (rank + 1) * partition_vocab_size],
rtol=rtol,
atol=atol,
)
loss_fn = CrossEntropyLoss(inplace_backward=True, reduction='none', process_group=process_group)
loss_fn_pt = CrossEntropyLoss(inplace_backward=True, reduction='none')
loss_fn = CrossEntropyLoss(inplace_backward=True, reduction="none", process_group=process_group)
loss_fn_pt = CrossEntropyLoss(inplace_backward=True, reduction="none")
loss = loss_fn(out, input_ids[:, 1:].flatten())
loss_pt = loss_fn_pt(out_pt, input_ids[:, 1:].flatten())
assert torch.allclose(loss, loss_pt, rtol=rtol, atol=atol)
......@@ -121,73 +132,105 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
allreduce_sequence_parallel_grad(model, process_group)
parallel_state.destroy_model_parallel()
grad_dict = shard_state_dict_tp({k: v.grad for k, v in model_pt.named_parameters()},
config, world_size, rank)
grad_dict = shard_state_dict_tp(
{k: v.grad for k, v in model_pt.named_parameters()}, config, world_size, rank
)
assert torch.allclose(
model.transformer.embeddings.word_embeddings.weight.grad,
grad_dict['transformer.embeddings.word_embeddings.weight'],
rtol=rtol, atol=atol * 5
grad_dict["transformer.embeddings.word_embeddings.weight"],
rtol=rtol,
atol=atol * 5,
)
if has_pos_emb:
assert torch.allclose(
model.transformer.embeddings.position_embeddings.weight.grad,
grad_dict['transformer.embeddings.position_embeddings.weight'],
rtol=rtol, atol=atol
grad_dict["transformer.embeddings.position_embeddings.weight"],
rtol=rtol,
atol=atol,
)
assert torch.allclose(model.transformer.ln_f.weight.grad, grad_dict['transformer.ln_f.weight'],
rtol=rtol, atol=atol)
assert torch.allclose(model.transformer.ln_f.bias.grad, grad_dict['transformer.ln_f.bias'],
rtol=rtol, atol=atol)
assert torch.allclose(
model.transformer.ln_f.weight.grad,
grad_dict["transformer.ln_f.weight"],
rtol=rtol,
atol=atol,
)
assert torch.allclose(
model.transformer.ln_f.bias.grad, grad_dict["transformer.ln_f.bias"], rtol=rtol, atol=atol
)
for i in range(num_layers):
assert torch.allclose(
model.transformer.layers[i].mixer.Wqkv.weight.grad,
grad_dict[f'transformer.layers.{i}.mixer.Wqkv.weight'],
rtol=rtol, atol=atol * 10
grad_dict[f"transformer.layers.{i}.mixer.Wqkv.weight"],
rtol=rtol,
atol=atol * 10,
)
assert torch.allclose(
model.transformer.layers[i].mixer.Wqkv.bias.grad,
grad_dict[f'transformer.layers.{i}.mixer.Wqkv.bias'],
rtol=rtol, atol=atol * 10
grad_dict[f"transformer.layers.{i}.mixer.Wqkv.bias"],
rtol=rtol,
atol=atol * 10,
)
assert torch.allclose(
model.transformer.layers[i].mixer.out_proj.weight.grad,
grad_dict[f'transformer.layers.{i}.mixer.out_proj.weight'],
rtol=rtol, atol=atol * 10
grad_dict[f"transformer.layers.{i}.mixer.out_proj.weight"],
rtol=rtol,
atol=atol * 10,
)
if rank == 0:
assert torch.allclose(model.transformer.layers[i].mixer.out_proj.bias.grad,
grad_dict[f'transformer.layers.{i}.mixer.out_proj.bias'],
rtol=rtol, atol=atol * 5)
assert torch.allclose(
model.transformer.layers[i].mixer.out_proj.bias.grad,
grad_dict[f"transformer.layers.{i}.mixer.out_proj.bias"],
rtol=rtol,
atol=atol * 5,
)
assert torch.allclose(
model.transformer.layers[i].mlp.fc1.weight.grad,
grad_dict[f'transformer.layers.{i}.mlp.fc1.weight'],
rtol=rtol, atol=atol * 10
grad_dict[f"transformer.layers.{i}.mlp.fc1.weight"],
rtol=rtol,
atol=atol * 10,
)
assert torch.allclose(
model.transformer.layers[i].mlp.fc1.bias.grad,
grad_dict[f'transformer.layers.{i}.mlp.fc1.bias'],
rtol=rtol, atol=atol * 10
grad_dict[f"transformer.layers.{i}.mlp.fc1.bias"],
rtol=rtol,
atol=atol * 10,
)
assert torch.allclose(
model.transformer.layers[i].mlp.fc2.weight.grad,
grad_dict[f'transformer.layers.{i}.mlp.fc2.weight'],
rtol=rtol, atol=atol * 10
grad_dict[f"transformer.layers.{i}.mlp.fc2.weight"],
rtol=rtol,
atol=atol * 10,
)
if rank == 0:
assert torch.allclose(model.transformer.layers[i].mlp.fc2.bias.grad,
grad_dict[f'transformer.layers.{i}.mlp.fc2.bias'],
rtol=rtol, atol=atol * 5)
assert torch.allclose(model.transformer.layers[i].norm1.weight.grad,
grad_dict[f'transformer.layers.{i}.norm1.weight'],
rtol=rtol, atol=atol)
assert torch.allclose(model.transformer.layers[i].norm1.bias.grad,
grad_dict[f'transformer.layers.{i}.norm1.bias'],
rtol=rtol, atol=atol)
assert torch.allclose(model.transformer.layers[i].norm2.weight.grad,
grad_dict[f'transformer.layers.{i}.norm2.weight'],
rtol=rtol, atol=atol)
assert torch.allclose(model.transformer.layers[i].norm2.bias.grad,
grad_dict[f'transformer.layers.{i}.norm2.bias'],
rtol=rtol, atol=atol)
assert torch.allclose(
model.transformer.layers[i].mlp.fc2.bias.grad,
grad_dict[f"transformer.layers.{i}.mlp.fc2.bias"],
rtol=rtol,
atol=atol * 5,
)
assert torch.allclose(
model.transformer.layers[i].norm1.weight.grad,
grad_dict[f"transformer.layers.{i}.norm1.weight"],
rtol=rtol,
atol=atol,
)
assert torch.allclose(
model.transformer.layers[i].norm1.bias.grad,
grad_dict[f"transformer.layers.{i}.norm1.bias"],
rtol=rtol,
atol=atol,
)
assert torch.allclose(
model.transformer.layers[i].norm2.weight.grad,
grad_dict[f"transformer.layers.{i}.norm2.weight"],
rtol=rtol,
atol=atol,
)
assert torch.allclose(
model.transformer.layers[i].norm2.bias.grad,
grad_dict[f"transformer.layers.{i}.norm2.bias"],
rtol=rtol,
atol=atol,
)
......@@ -2,37 +2,35 @@
import time
import torch
import pytest
from transformers import GPTJConfig, AutoTokenizer
from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
import torch
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gptj import remap_state_dict_hf_gptj, gptj_config_to_gpt2_config
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.models.gptj import gptj_config_to_gpt2_config, remap_state_dict_hf_gptj
from flash_attn.utils.generation import update_graph_cache
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import AutoTokenizer, GPTJConfig
from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
@pytest.mark.parametrize('model_name', ["EleutherAI/gpt-j-6B"])
@pytest.mark.parametrize("model_name", ["EleutherAI/gpt-j-6B"])
def test_gptj_state_dict(model_name):
config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
pretrained_state_dict = remap_state_dict_hf_gptj(state_dict_from_pretrained(model_name), config)
model = GPTLMHeadModel(config, device='meta') # Without device='meta' init is very slow
model = GPTLMHeadModel(config, device="meta") # Without device='meta' init is very slow
state_dict = model.state_dict()
assert state_dict.keys() == pretrained_state_dict.keys()
for k in state_dict.keys():
assert state_dict[k].shape == pretrained_state_dict[k].shape
@pytest.mark.parametrize('model_name', ["EleutherAI/gpt-j-6B"])
@pytest.mark.parametrize("model_name", ["EleutherAI/gpt-j-6B"])
def test_gptj_optimized(model_name):
"""Check that our implementation of GPT-J (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
dtype = torch.float16
device = 'cuda'
device = "cuda"
config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
config.use_flash_attn = True # FlashAttention-2 supports headdim 256
config.fused_bias_fc = True
......@@ -46,8 +44,9 @@ def test_gptj_optimized(model_name):
torch.manual_seed(0)
batch_size = 2
max_seqlen = 256
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
device=device)
input_ids = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
)
with torch.no_grad():
out = model.transformer(input_ids)
logits = model(input_ids).logits
......@@ -61,34 +60,37 @@ def test_gptj_optimized(model_name):
logits_ref = model_ref(input_ids).logits
del model_ref
model_hf = GPTJForCausalLM.from_pretrained(model_name, torch_dtype=dtype,
device_map={"": device})
model_hf = GPTJForCausalLM.from_pretrained(
model_name, torch_dtype=dtype, device_map={"": device}
)
model_hf.eval()
out_hf = model_hf.transformer(input_ids).last_hidden_state
logits_hf = model_hf(input_ids).logits
del model_hf
print(f'Output max diff: {(out - out_ref).abs().max().item()}')
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item()
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
assert (logits - logits_ref).abs().max().item() < 3 * (
logits_hf - logits_ref
).abs().max().item()
@pytest.mark.parametrize('model_name', ["EleutherAI/gpt-j-6B"])
@pytest.mark.parametrize("model_name", ["EleutherAI/gpt-j-6B"])
def test_gptj_generation(model_name):
"""Check that our implementation of GPT-J (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
dtype = torch.float16
device = 'cuda'
device = "cuda"
config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
config.use_flash_attn = True # FlashAttention-2 supports headdim 256
config.fused_bias_fc = True
......@@ -104,56 +106,71 @@ def test_gptj_generation(model_name):
batch_size = 1
seqlen = 100
max_length = 150
input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long,
device=device)
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
)
model_hf = GPTJForCausalLM.from_pretrained(model_name, torch_dtype=dtype,
device_map={"": device})
model_hf = GPTJForCausalLM.from_pretrained(
model_name, torch_dtype=dtype, device_map={"": device}
)
model_hf.eval()
print("HF fp16")
torch.cuda.synchronize()
start = time.time()
out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
return_dict_in_generate=True, output_scores=True)
out_hf = model_hf.generate(
input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
)
torch.cuda.synchronize()
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
del model_hf
model_ref = GPTJForCausalLM.from_pretrained(model_name, device_map={"": device})
model_ref.eval()
with torch.no_grad():
logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1]
logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]
del model_ref
model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
model.eval()
print('Without CUDA graph')
print("Without CUDA graph")
torch.cuda.synchronize()
start = time.time()
out = model.generate(input_ids=input_ids, max_length=max_length,
eos_token_id=eos_token_id, fused_ft_kernel=True,
# eos_token_id=eos_token_id, fused_ft_kernel=False,
return_dict_in_generate=True, output_scores=True, timing=True,
teacher_outputs=out_hf.sequences)
out = model.generate(
input_ids=input_ids,
max_length=max_length,
eos_token_id=eos_token_id,
fused_ft_kernel=True,
# eos_token_id=eos_token_id, fused_ft_kernel=False,
return_dict_in_generate=True,
output_scores=True,
timing=True,
teacher_outputs=out_hf.sequences,
)
torch.cuda.synchronize()
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
# Capture graph outside the timing loop
batch_size, seqlen_og = input_ids.shape
model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
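# The CUDA graph is captured here, outside the timed region, so the "With CUDA graph" timing
# below measures only prompt processing + decoding; the final torch.equal(logits_cg, logits)
# check then expects graph replay to give bit-identical scores.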
print('With CUDA graph')
print("With CUDA graph")
torch.cuda.synchronize()
start = time.time()
out_cg = model.generate(input_ids=input_ids, max_length=max_length,
fused_ft_kernel=True, cg=True,
return_dict_in_generate=True, output_scores=True, timing=True,
teacher_outputs=out_hf.sequences)
out_cg = model.generate(
input_ids=input_ids,
max_length=max_length,
fused_ft_kernel=True,
cg=True,
return_dict_in_generate=True,
output_scores=True,
timing=True,
teacher_outputs=out_hf.sequences,
)
torch.cuda.synchronize()
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
with torch.no_grad():
logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1):-1]
logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1]
logits_hf = torch.stack(out_hf.scores, dim=1)
logits = torch.stack(out.scores, dim=1)
logits_cg = torch.stack(out_cg.scores, dim=1)
......@@ -163,8 +180,8 @@ def test_gptj_generation(model_name):
hf_error = (logits_hf - logits_ref).abs().max().item()
assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error
print(f'HF fp16 logits max diff: {hf_error}')
print(f'Logits max diff: {(logits - logits_ref).abs().max().item() }')
print(f"HF fp16 logits max diff: {hf_error}")
print(f"Logits max diff: {(logits - logits_ref).abs().max().item() }")
assert (logits - logits_ref).abs().max().item() < 2 * hf_error
print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }')
print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }")
assert torch.equal(logits_cg, logits)
......@@ -11,26 +11,25 @@ from pathlib import Path
current_dir = Path(__file__).parent.absolute()
import torch
import pytest
import shutil
import pytest
import torch
from einops import rearrange
from transformers import LlamaTokenizer, LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp, shard_state_dict_tp
from flash_attn.models.llama import (
remap_state_dict_meta_llama,
config_from_checkpoint,
inv_remap_state_dict_hf_llama,
llama_config_to_gpt2_config,
remap_state_dict_hf_llama,
inv_remap_state_dict_hf_llama,
remap_state_dict_meta_llama,
state_dicts_from_checkpoint,
)
from flash_attn.models.llama import config_from_checkpoint, state_dicts_from_checkpoint
from flash_attn.utils.distributed import all_gather_raw
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.generation import update_graph_cache
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import LlamaConfig, LlamaTokenizer
from transformers.models.llama.modeling_llama import LlamaForCausalLM
def _pretrained_state_dict_from_checkpoint(checkpoint_path, model_name, config, checkpoint_format):
......
import re
import torch
import pytest
from transformers import OPTConfig
from transformers.models.opt.modeling_opt import OPTForCausalLM
import torch
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.opt import remap_state_dict_hf_opt, opt_config_to_gpt2_config
from flash_attn.models.opt import opt_config_to_gpt2_config, remap_state_dict_hf_opt
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import OPTConfig
from transformers.models.opt.modeling_opt import OPTForCausalLM
@pytest.mark.parametrize('model_name', ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b"])
@pytest.mark.parametrize(
"model_name", ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b"]
)
# @pytest.mark.parametrize('model_name', ["facebook/opt-350m"])
def test_opt_state_dict(model_name):
config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
......@@ -23,7 +23,9 @@ def test_opt_state_dict(model_name):
assert state_dict[k].shape == pretrained_state_dict[k].shape
@pytest.mark.parametrize('model_name', ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b"])
@pytest.mark.parametrize(
"model_name", ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b"]
)
# @pytest.mark.parametrize('model_name', ["facebook/opt-350m"])
def test_opt_optimized(model_name):
"""Check that our implementation of OPT (without all optimizations enabled) matches the
......@@ -31,14 +33,14 @@ def test_opt_optimized(model_name):
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
dtype = torch.float16
device = 'cuda'
device = "cuda"
config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = True
config.fused_dropout_add_ln = True
# Only prenorm supports residual_in_fp32
config.residual_in_fp32 = getattr(config, 'prenorm', True)
config.residual_in_fp32 = getattr(config, "prenorm", True)
config.pad_vocab_size_multiple = 8
model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
......@@ -53,26 +55,29 @@ def test_opt_optimized(model_name):
torch.manual_seed(0)
batch_size = 2
max_seqlen = 256
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda')
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
device='cuda')
if model_name != 'facebook/opt-350m': # The OPT-350m projects the embeddings to dimension 512
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
input_ids = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
)
if model_name != "facebook/opt-350m": # The OPT-350m projects the embeddings to dimension 512
out = model.transformer(input_ids)
out_hf = model_hf.model(input_ids).last_hidden_state
out_ref = model_ref.model(input_ids).last_hidden_state
print(f'Output max diff: {(out - out_ref).abs().max().item()}')
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
logits = model(input_ids).logits
logits_hf = model_hf(input_ids).logits
logits_ref = model_ref(input_ids).logits
print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item()
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
assert (logits - logits_ref).abs().max().item() < 3 * (
logits_hf - logits_ref
).abs().max().item()
import re
import torch
import pytest
from timm.models.vision_transformer import vit_base_patch16_224
import torch
from flash_attn.models.vit import vit_base_patch16_224 as flash_vit_base_patch16_224
from timm.models.vision_transformer import vit_base_patch16_224
@pytest.mark.parametrize('fused_mlp', [False, True])
@pytest.mark.parametrize("fused_mlp", [False, True])
# @pytest.mark.parametrize('fused_mlp', [False])
@pytest.mark.parametrize('optimized', [False, True])
@pytest.mark.parametrize("optimized", [False, True])
# @pytest.mark.parametrize('optimized', [True])
def test_vit(optimized, fused_mlp):
"""Check that our implementation of ViT matches the timm's implementation:
......@@ -18,12 +16,12 @@ def test_vit(optimized, fused_mlp):
timm's forward pass in fp16, when compared to timm's forward pass in fp32.
"""
dtype = torch.float16
device = 'cuda'
device = "cuda"
kwargs = {}
if optimized:
kwargs = dict(use_flash_attn=True, fused_bias_fc=True, fused_dropout_add_ln=True)
kwargs['fused_mlp'] = fused_mlp
kwargs["fused_mlp"] = fused_mlp
model = flash_vit_base_patch16_224(**kwargs).to(device=device, dtype=dtype)
model_ref = vit_base_patch16_224(pretrained=True).to(device=device)
......@@ -42,9 +40,9 @@ def test_vit(optimized, fused_mlp):
out_timm = model_timm(x)
out_ref = model_ref(x.float())
print(f'Output max diff: {(out - out_ref).abs().max().item()}')
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
print(f'timm fp16 max diff: {(out_timm - out_ref).abs().max().item()}')
print(f'timm fp16 mean diff: {(out_timm - out_ref).abs().mean().item()}')
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"timm fp16 max diff: {(out_timm - out_ref).abs().max().item()}")
print(f"timm fp16 mean diff: {(out_timm - out_ref).abs().mean().item()}")
rtol = 2 if not fused_mlp else 8
assert (out - out_ref).abs().max().item() < rtol * (out_timm - out_ref).abs().max().item()
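# Here rtol acts as a multiplier on timm's own fp16 error rather than a torch.allclose-style
# relative tolerance; the looser 8x factor is presumably needed because the fused MLP changes
# the numerics.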
......@@ -4,31 +4,27 @@
import math
from functools import partial
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytest
from apex.transformer import parallel_state, tensor_parallel
from einops import rearrange
from apex.transformer import parallel_state
from apex.transformer import tensor_parallel
from flash_attn.modules.block import Block
from flash_attn.modules.mha import MHA, ParallelMHA
from flash_attn.modules.mlp import FusedMLP, ParallelFusedMLP
from flash_attn.modules.block import Block
from flash_attn.utils.distributed import allreduce_sequence_parallel_grad
is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
@pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('sequence_parallel', [True, False])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [True])
@pytest.mark.parametrize('dim', [1024])
@pytest.mark.parametrize("dim", [1024])
def test_block_parallel(dim, sequence_parallel, world_size, dtype):
head_dim = 64
assert dim % head_dim == 0
......@@ -36,8 +32,8 @@ def test_block_parallel(dim, sequence_parallel, world_size, dtype):
assert num_heads % world_size == 0
rtol, atol = (3e-3, 5e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend='nccl', init_method='env://')
device = f'cuda:{torch.distributed.get_rank()}'
torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
......@@ -46,22 +42,37 @@ def test_block_parallel(dim, sequence_parallel, world_size, dtype):
batch_size = 2
seqlen = 1024
assert (batch_size * seqlen) % world_size == 0
x_pt = torch.randn(batch_size * seqlen, dim, device=device, dtype=dtype,
requires_grad=True)
x_pt = torch.randn(batch_size * seqlen, dim, device=device, dtype=dtype, requires_grad=True)
residual_pt = torch.randn(batch_size * seqlen, dim, device=device, requires_grad=True)
# We need to generate g here so that all processes get the same gradient,
# as rank 0 will have an extra bias that changes the RNG.
# If we don't divide by batch_size, the gradient gets a bit too large.
g = torch.randn_like(x_pt) / 32
if sequence_parallel:
x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
residual = tensor_parallel.scatter_to_sequence_parallel_region(residual_pt).detach().clone().requires_grad_()
x = (
tensor_parallel.scatter_to_sequence_parallel_region(x_pt)
.detach()
.clone()
.requires_grad_()
)
residual = (
tensor_parallel.scatter_to_sequence_parallel_region(residual_pt)
.detach()
.clone()
.requires_grad_()
)
else:
x = x_pt.detach().clone().requires_grad_()
residual = residual_pt.detach().clone().requires_grad_()
mixer_cls_pt = partial(MHA, num_heads=num_heads, rotary_emb_dim=int(head_dim // 2),
use_flash_attn=True, device=device, dtype=dtype)
mixer_cls_pt = partial(
MHA,
num_heads=num_heads,
rotary_emb_dim=int(head_dim // 2),
use_flash_attn=True,
device=device,
dtype=dtype,
)
mlp_cls_pt = partial(FusedMLP, hidden_features=4 * dim, device=device, dtype=dtype)
norm_cls = partial(nn.LayerNorm, device=device, dtype=dtype)
model_pt = Block(dim, mixer_cls_pt, mlp_cls_pt, norm_cls, fused_dropout_add_ln=True)
......@@ -71,40 +82,68 @@ def test_block_parallel(dim, sequence_parallel, world_size, dtype):
nn.init.normal_(model_pt.norm2.weight)
nn.init.normal_(model_pt.norm2.bias)
mixer_cls = partial(ParallelMHA, num_heads=num_heads,
process_group=parallel_state.get_tensor_model_parallel_group(),
rotary_emb_dim=int(head_dim // 2), use_flash_attn=True,
sequence_parallel=sequence_parallel, device=device, dtype=dtype)
mlp_cls = partial(ParallelFusedMLP, hidden_features=4 * dim,
process_group=parallel_state.get_tensor_model_parallel_group(),
sequence_parallel=sequence_parallel, device=device, dtype=dtype)
model = Block(dim, mixer_cls, mlp_cls, norm_cls, fused_dropout_add_ln=True,
sequence_parallel=sequence_parallel, mark_shared_params=True)
mixer_cls = partial(
ParallelMHA,
num_heads=num_heads,
process_group=parallel_state.get_tensor_model_parallel_group(),
rotary_emb_dim=int(head_dim // 2),
use_flash_attn=True,
sequence_parallel=sequence_parallel,
device=device,
dtype=dtype,
)
mlp_cls = partial(
ParallelFusedMLP,
hidden_features=4 * dim,
process_group=parallel_state.get_tensor_model_parallel_group(),
sequence_parallel=sequence_parallel,
device=device,
dtype=dtype,
)
model = Block(
dim,
mixer_cls,
mlp_cls,
norm_cls,
fused_dropout_add_ln=True,
sequence_parallel=sequence_parallel,
mark_shared_params=True,
)
partition_dim = dim // world_size
partition_hidden_dim = 4 * dim // world_size
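# Sharding convention for the copies below (standard column-/row-parallel layout): Wqkv and
# mlp.fc1 are sharded along their output dimension (for Wqkv the slice is taken per Q/K/V
# projection via the rearrange), while out_proj and mlp.fc2 are sharded along their input
# dimension, with their biases kept only on rank 0.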
with torch.no_grad():
model.mixer.Wqkv.weight.copy_(
rearrange(rearrange(model_pt.mixer.Wqkv.weight, '(three o) i -> three o i', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim],
'three o i -> (three o) i')
rearrange(
rearrange(model_pt.mixer.Wqkv.weight, "(three o) i -> three o i", three=3)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"three o i -> (three o) i",
)
)
model.mixer.Wqkv.bias.copy_(
rearrange(rearrange(model_pt.mixer.Wqkv.bias, '(three o) -> three o', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim],
'three o -> (three o)')
rearrange(
rearrange(model_pt.mixer.Wqkv.bias, "(three o) -> three o", three=3)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"three o -> (three o)",
)
)
model.mixer.out_proj.weight.copy_(
model_pt.mixer.out_proj.weight[:, rank * partition_dim:(rank + 1) * partition_dim]
model_pt.mixer.out_proj.weight[:, rank * partition_dim : (rank + 1) * partition_dim]
)
if rank == 0:
model.mixer.out_proj.bias.copy_(model_pt.mixer.out_proj.bias)
model.mlp.fc1.weight.copy_(
model_pt.mlp.fc1.weight[rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim]
model_pt.mlp.fc1.weight[rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim]
)
model.mlp.fc1.bias.copy_(
model_pt.mlp.fc1.bias[rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim]
model_pt.mlp.fc1.bias[rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim]
)
model.mlp.fc2.weight.copy_(
model_pt.mlp.fc2.weight[:, rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim]
model_pt.mlp.fc2.weight[
:, rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim
]
)
if rank == 0:
model.mlp.fc2.bias.copy_(model_pt.mlp.fc2.bias)
......@@ -113,83 +152,122 @@ def test_block_parallel(dim, sequence_parallel, world_size, dtype):
model.norm2.weight.copy_(model_pt.norm2.weight)
model.norm2.bias.copy_(model_pt.norm2.bias)
mixer_kwargs = {'seqlen': seqlen}
mixer_kwargs = {"seqlen": seqlen}
out, out_residual = model(x, residual, mixer_kwargs=mixer_kwargs)
out_pt, out_residual_pt = model_pt(rearrange(x_pt, '(b s) d -> b s d', s=seqlen),
rearrange(residual_pt, '(b s) d -> b s d', s=seqlen))
out_pt, out_residual_pt = [rearrange(x, 'b s d -> (b s) d') for x in [out_pt, out_residual_pt]]
out_pt, out_residual_pt = model_pt(
rearrange(x_pt, "(b s) d -> b s d", s=seqlen),
rearrange(residual_pt, "(b s) d -> b s d", s=seqlen),
)
out_pt, out_residual_pt = [rearrange(x, "b s d -> (b s) d") for x in [out_pt, out_residual_pt]]
partition_batch_dim = batch_size * seqlen // world_size
assert torch.allclose(
out,
out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else out_pt,
rtol=rtol, atol=atol
out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel
else out_pt,
rtol=rtol,
atol=atol,
)
assert torch.allclose(
out_residual,
out_residual_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else out_residual_pt,
rtol=rtol, atol=atol
out_residual_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel
else out_residual_pt,
rtol=rtol,
atol=atol,
)
(out_pt + 2 * out_residual_pt).backward(g)
(out + 2 * out_residual).backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else g)
(out + 2 * out_residual).backward(
g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else g
)
allreduce_sequence_parallel_grad(model, parallel_state.get_tensor_model_parallel_group())
parallel_state.destroy_model_parallel()
assert torch.allclose(
x.grad,
x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else x_pt.grad,
rtol=rtol, atol=atol / 10 # magnitude of x.grad is quite small
x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel
else x_pt.grad,
rtol=rtol,
atol=atol / 10, # magnitude of x.grad is quite small
)
assert torch.allclose(
residual.grad,
residual_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else residual_pt.grad,
rtol=rtol, atol=atol
residual_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel
else residual_pt.grad,
rtol=rtol,
atol=atol,
)
# The error for d_weight and d_bias is quite a bit higher
assert torch.allclose(
model.mixer.Wqkv.weight.grad,
rearrange(rearrange(model_pt.mixer.Wqkv.weight.grad, '(three o) i -> three o i', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim],
'three o i -> (three o) i'),
rtol=rtol, atol=atol * 10
rearrange(
rearrange(model_pt.mixer.Wqkv.weight.grad, "(three o) i -> three o i", three=3)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"three o i -> (three o) i",
),
rtol=rtol,
atol=atol * 10,
)
assert torch.allclose(
model.mixer.Wqkv.bias.grad,
rearrange(rearrange(model_pt.mixer.Wqkv.bias.grad, '(three o) -> three o', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim],
'three o -> (three o)'),
rtol=rtol, atol=atol * 5
rearrange(
rearrange(model_pt.mixer.Wqkv.bias.grad, "(three o) -> three o", three=3)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"three o -> (three o)",
),
rtol=rtol,
atol=atol * 5,
)
assert torch.allclose(
model.mixer.out_proj.weight.grad,
model_pt.mixer.out_proj.weight.grad[:, rank * partition_dim:(rank + 1) * partition_dim],
rtol=rtol, atol=atol * 10
model_pt.mixer.out_proj.weight.grad[:, rank * partition_dim : (rank + 1) * partition_dim],
rtol=rtol,
atol=atol * 10,
)
if rank == 0:
assert torch.allclose(model.mixer.out_proj.bias.grad, model_pt.mixer.out_proj.bias.grad, rtol=rtol, atol=atol * 5)
assert torch.allclose(
model.mixer.out_proj.bias.grad,
model_pt.mixer.out_proj.bias.grad,
rtol=rtol,
atol=atol * 5,
)
assert torch.allclose(
model.mlp.fc1.weight.grad,
model_pt.mlp.fc1.weight.grad[rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim],
rtol=rtol, atol=atol * 10
model_pt.mlp.fc1.weight.grad[
rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim
],
rtol=rtol,
atol=atol * 10,
)
assert torch.allclose(
model.mlp.fc1.bias.grad,
model_pt.mlp.fc1.bias.grad[rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim],
rtol=rtol, atol=atol * 5
model_pt.mlp.fc1.bias.grad[rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim],
rtol=rtol,
atol=atol * 5,
)
assert torch.allclose(
model.mlp.fc2.weight.grad,
model_pt.mlp.fc2.weight.grad[:, rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim],
rtol=rtol, atol=atol * 10
model_pt.mlp.fc2.weight.grad[
:, rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim
],
rtol=rtol,
atol=atol * 10,
)
if rank == 0:
assert torch.allclose(model.mlp.fc2.bias.grad, model_pt.mlp.fc2.bias.grad,
rtol=rtol, atol=atol * 5)
assert torch.allclose(
model.mlp.fc2.bias.grad, model_pt.mlp.fc2.bias.grad, rtol=rtol, atol=atol * 5
)
assert torch.allclose(model.norm1.weight.grad, model_pt.norm1.weight.grad, rtol=rtol, atol=atol * 5)
assert torch.allclose(
model.norm1.weight.grad, model_pt.norm1.weight.grad, rtol=rtol, atol=atol * 5
)
assert torch.allclose(model.norm1.bias.grad, model_pt.norm1.bias.grad, rtol=rtol, atol=atol * 5)
assert torch.allclose(model.norm2.weight.grad, model_pt.norm2.weight.grad, rtol=rtol, atol=atol * 5)
assert torch.allclose(
model.norm2.weight.grad, model_pt.norm2.weight.grad, rtol=rtol, atol=atol * 5
)
assert torch.allclose(model.norm2.bias.grad, model_pt.norm2.bias.grad, rtol=rtol, atol=atol * 5)
# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/modules/test_embedding_parallel.py
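# (nproc_per_node should be at least the largest parametrized world_size, 8 here, since the
# test asserts world_size <= torch.distributed.get_world_size().)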
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytest
from einops import rearrange
from apex.transformer import parallel_state
from einops import rearrange
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
@pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('sequence_parallel', [True, False])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize('has_pos_emb', [True, False])
@pytest.mark.parametrize("has_pos_emb", [True, False])
# @pytest.mark.parametrize('has_pos_emb', [True])
@pytest.mark.parametrize('dim', [1024])
@pytest.mark.parametrize("dim", [1024])
def test_embedding_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
vocab_size = 50264
seqlen = 2048
......@@ -31,8 +28,8 @@ def test_embedding_parallel(dim, has_pos_emb, sequence_parallel, world_size, dty
assert dim % world_size == 0
rtol, atol = (3e-3, 5e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend='nccl', init_method='env://')
device = f'cuda:{torch.distributed.get_rank()}'
torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
......@@ -44,46 +41,66 @@ def test_embedding_parallel(dim, has_pos_emb, sequence_parallel, world_size, dty
input_ids_pt = torch.randint(0, vocab_size, (batch_size, seqlen), device=device)
input_ids = input_ids_pt.detach().clone()
model_pt = GPT2Embeddings(dim, vocab_size, seqlen if has_pos_emb else 0,
device=device, dtype=dtype)
model = ParallelGPT2Embeddings(dim, vocab_size, seqlen if has_pos_emb else 0,
parallel_state.get_tensor_model_parallel_group(),
sequence_parallel=sequence_parallel, device=device, dtype=dtype)
model_pt = GPT2Embeddings(
dim, vocab_size, seqlen if has_pos_emb else 0, device=device, dtype=dtype
)
model = ParallelGPT2Embeddings(
dim,
vocab_size,
seqlen if has_pos_emb else 0,
parallel_state.get_tensor_model_parallel_group(),
sequence_parallel=sequence_parallel,
device=device,
dtype=dtype,
)
partition_vocab_size = vocab_size // world_size
partition_dim = dim // world_size
with torch.no_grad():
model.word_embeddings.weight.copy_(
model_pt.word_embeddings.weight[rank * partition_vocab_size:(rank + 1) * partition_vocab_size]
model_pt.word_embeddings.weight[
rank * partition_vocab_size : (rank + 1) * partition_vocab_size
]
)
if has_pos_emb:
model.position_embeddings.weight.copy_(
model_pt.position_embeddings.weight[:, rank * partition_dim:(rank + 1) * partition_dim]
model_pt.position_embeddings.weight[
:, rank * partition_dim : (rank + 1) * partition_dim
]
)
out = model(input_ids, combine_batch_seqlen_dim=True)
out_pt = rearrange(model_pt(input_ids), 'b s d -> (b s) d')
out_pt = rearrange(model_pt(input_ids), "b s d -> (b s) d")
partition_batch_dim = batch_size * seqlen // world_size
assert torch.allclose(
out,
out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else out_pt,
rtol=rtol, atol=atol
out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel
else out_pt,
rtol=rtol,
atol=atol,
)
g = torch.randn_like(out_pt)
out_pt.backward(g)
out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else g)
out.backward(
g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else g
)
parallel_state.destroy_model_parallel()
assert torch.allclose(
model.word_embeddings.weight.grad,
model_pt.word_embeddings.weight.grad[rank * partition_vocab_size:(rank + 1) * partition_vocab_size],
rtol=rtol, atol=atol
model_pt.word_embeddings.weight.grad[
rank * partition_vocab_size : (rank + 1) * partition_vocab_size
],
rtol=rtol,
atol=atol,
)
if has_pos_emb:
assert torch.allclose(
model.position_embeddings.weight.grad,
model_pt.position_embeddings.weight.grad[:, rank * partition_dim:(rank + 1) * partition_dim],
rtol=rtol, atol=atol
model_pt.position_embeddings.weight.grad[
:, rank * partition_dim : (rank + 1) * partition_dim
],
rtol=rtol,
atol=atol,
)
......@@ -3,29 +3,25 @@
import math
import pytest
import torch
import torch.nn.functional as F
import pytest
from apex.transformer import parallel_state, tensor_parallel
from einops import rearrange
from apex.transformer import parallel_state
from apex.transformer import tensor_parallel
from flash_attn.modules.mha import MHA, ParallelMHA
is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
@pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('sequence_parallel', [True, False])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize('head_dim', [64, 128])
@pytest.mark.parametrize("head_dim", [64, 128])
# @pytest.mark.parametrize('head_dim', [64])
@pytest.mark.parametrize('embed_dim', [1024, 4096])
@pytest.mark.parametrize("embed_dim", [1024, 4096])
# @pytest.mark.parametrize('embed_dim', [1024])
def test_mha_parallel(embed_dim, head_dim, sequence_parallel, world_size, dtype):
assert embed_dim % head_dim == 0
......@@ -33,8 +29,8 @@ def test_mha_parallel(embed_dim, head_dim, sequence_parallel, world_size, dtype)
assert num_heads % world_size == 0
rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend='nccl', init_method='env://')
device = f'cuda:{torch.distributed.get_rank()}'
torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
......@@ -43,77 +39,122 @@ def test_mha_parallel(embed_dim, head_dim, sequence_parallel, world_size, dtype)
batch_size = 2
seqlen = 1024
assert (batch_size * seqlen) % world_size == 0
x_pt = torch.randn(batch_size * seqlen, embed_dim, device=device, dtype=dtype,
requires_grad=True)
x_pt = torch.randn(
batch_size * seqlen, embed_dim, device=device, dtype=dtype, requires_grad=True
)
# We need to generate g here so that all processes get the same gradient,
# as rank 0 will have an extra bias that changes the RNG.
# If we don't divide by batch_size, the gradient gets a bit too large.
g = torch.randn_like(x_pt) / 32
if sequence_parallel:
x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
x = (
tensor_parallel.scatter_to_sequence_parallel_region(x_pt)
.detach()
.clone()
.requires_grad_()
)
else:
x = x_pt.detach().clone().requires_grad_()
model_pt = MHA(embed_dim, num_heads, rotary_emb_dim=int(head_dim // 2),
use_flash_attn=True, device=device, dtype=dtype)
model_pt = MHA(
embed_dim,
num_heads,
rotary_emb_dim=int(head_dim // 2),
use_flash_attn=True,
device=device,
dtype=dtype,
)
partition_dim = embed_dim // world_size
model = ParallelMHA(embed_dim, num_heads, parallel_state.get_tensor_model_parallel_group(),
rotary_emb_dim=int(head_dim // 2), use_flash_attn=True,
sequence_parallel=sequence_parallel, device=device, dtype=dtype)
model = ParallelMHA(
embed_dim,
num_heads,
parallel_state.get_tensor_model_parallel_group(),
rotary_emb_dim=int(head_dim // 2),
use_flash_attn=True,
sequence_parallel=sequence_parallel,
device=device,
dtype=dtype,
)
with torch.no_grad():
model.Wqkv.weight.copy_(
rearrange(rearrange(model_pt.Wqkv.weight, '(three o) i -> three o i', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim],
'three o i -> (three o) i')
rearrange(
rearrange(model_pt.Wqkv.weight, "(three o) i -> three o i", three=3)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"three o i -> (three o) i",
)
)
model.Wqkv.bias.copy_(
rearrange(rearrange(model_pt.Wqkv.bias, '(three o) -> three o', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim],
'three o -> (three o)')
rearrange(
rearrange(model_pt.Wqkv.bias, "(three o) -> three o", three=3)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"three o -> (three o)",
)
)
model.out_proj.weight.copy_(
model_pt.out_proj.weight[:, rank * partition_dim:(rank + 1) * partition_dim]
model_pt.out_proj.weight[:, rank * partition_dim : (rank + 1) * partition_dim]
)
if rank == 0:
model.out_proj.bias.copy_(model_pt.out_proj.bias)
out = model(x, seqlen=seqlen)
out_pt = rearrange(model_pt(rearrange(x_pt, '(b s) d -> b s d', s=seqlen)), 'b s d -> (b s) d')
out_pt = rearrange(model_pt(rearrange(x_pt, "(b s) d -> b s d", s=seqlen)), "b s d -> (b s) d")
partition_batch_dim = batch_size * seqlen // world_size
assert torch.allclose(
out,
out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else out_pt,
rtol=rtol, atol=atol
out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel
else out_pt,
rtol=rtol,
atol=atol,
)
out_pt.backward(g)
out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else g)
out.backward(
g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else g
)
parallel_state.destroy_model_parallel()
assert torch.allclose(
x.grad,
x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else x_pt.grad,
rtol=rtol, atol=atol / 100 # magnitude of x.grad is quite small
x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel
else x_pt.grad,
rtol=rtol,
atol=atol / 100, # magnitude of x.grad is quite small
)
# The error for d_weight and d_bias is quite a bit higher
assert torch.allclose(
model.Wqkv.weight.grad,
rearrange(rearrange(model_pt.Wqkv.weight.grad, '(three o) i -> three o i', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim],
'three o i -> (three o) i'),
rtol=rtol, atol=atol * 10
rearrange(
rearrange(model_pt.Wqkv.weight.grad, "(three o) i -> three o i", three=3)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"three o i -> (three o) i",
),
rtol=rtol,
atol=atol * 10,
)
assert torch.allclose(
model.Wqkv.bias.grad,
rearrange(rearrange(model_pt.Wqkv.bias.grad, '(three o) -> three o', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim],
'three o -> (three o)'),
rtol=rtol, atol=atol * 5
rearrange(
rearrange(model_pt.Wqkv.bias.grad, "(three o) -> three o", three=3)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"three o -> (three o)",
),
rtol=rtol,
atol=atol * 5,
)
assert torch.allclose(
model.out_proj.weight.grad,
model_pt.out_proj.weight.grad[:, rank * partition_dim:(rank + 1) * partition_dim],
rtol=rtol, atol=atol * 10
model_pt.out_proj.weight.grad[:, rank * partition_dim : (rank + 1) * partition_dim],
rtol=rtol,
atol=atol * 10,
)
if rank == 0:
assert torch.allclose(model.out_proj.bias.grad, model_pt.out_proj.bias.grad, rtol=rtol, atol=atol * 5)
assert torch.allclose(
model.out_proj.bias.grad, model_pt.out_proj.bias.grad, rtol=rtol, atol=atol * 5
)
# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/modules/test_mlp_parallel.py
import pytest
import torch
import torch.nn.functional as F
import pytest
from apex.transformer import parallel_state, tensor_parallel
from einops import rearrange
from apex.transformer import parallel_state
from apex.transformer import tensor_parallel
from flash_attn.modules.mlp import GatedMlp, ParallelGatedMlp
is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
@pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('sequence_parallel', [True, False])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize('activation', [F.silu, F.sigmoid])
@pytest.mark.parametrize("activation", [F.silu, F.sigmoid])
# @pytest.mark.parametrize('activation', [F.silu])
@pytest.mark.parametrize('dim', [1024, 4096])
@pytest.mark.parametrize("dim", [1024, 4096])
# @pytest.mark.parametrize('dim', [1024])
def test_mlp_parallel(dim, activation, sequence_parallel, world_size, dtype):
rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend='nccl', init_method='env://')
device = f'cuda:{torch.distributed.get_rank()}'
torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
@@ -39,34 +35,51 @@ def test_mlp_parallel(dim, activation, sequence_parallel, world_size, dtype):
batch_size = 2
seqlen = 1024
assert (batch_size * seqlen) % world_size == 0
x_pt = torch.randn(batch_size * seqlen, dim, device=device, dtype=dtype,
requires_grad=True)
x_pt = torch.randn(batch_size * seqlen, dim, device=device, dtype=dtype, requires_grad=True)
# We need to generate g here so that all processes get the same gradient,
# as rank 0 will have an extra bias that changes the RNG.
# If we don't divide by batch_size, the gradient gets a bit too large.
g = torch.randn_like(x_pt) / 32
if sequence_parallel:
x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
x = (
tensor_parallel.scatter_to_sequence_parallel_region(x_pt)
.detach()
.clone()
.requires_grad_()
)
else:
x = x_pt.detach().clone().requires_grad_()
model_pt = GatedMlp(dim, activation=activation, device=device, dtype=dtype)
partition_dim = model_pt.fc1.weight.shape[0] // 2 // world_size
model = ParallelGatedMlp(dim, parallel_state.get_tensor_model_parallel_group(),
activation=activation,
sequence_parallel=sequence_parallel, device=device, dtype=dtype)
model = ParallelGatedMlp(
dim,
parallel_state.get_tensor_model_parallel_group(),
activation=activation,
sequence_parallel=sequence_parallel,
device=device,
dtype=dtype,
)
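    # Same sharding scheme as the MHA test: fc1 fuses the gate and value projections,
    # so its weight is viewed as (two o) and each rank copies a partition_dim-wide
    # slice of both halves; fc2 is row-parallel and is sliced along its input columns.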
with torch.no_grad():
model.fc1.weight.copy_(
rearrange(rearrange(model_pt.fc1.weight, '(two o) i -> two o i', two=2)[:, rank * partition_dim:(rank + 1) * partition_dim],
'two o i -> (two o) i')
rearrange(
rearrange(model_pt.fc1.weight, "(two o) i -> two o i", two=2)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"two o i -> (two o) i",
)
)
model.fc1.bias.copy_(
rearrange(rearrange(model_pt.fc1.bias, '(two o) -> two o', two=2)[:, rank * partition_dim:(rank + 1) * partition_dim],
'two o -> (two o)')
rearrange(
rearrange(model_pt.fc1.bias, "(two o) -> two o", two=2)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"two o -> (two o)",
)
)
model.fc2.weight.copy_(
model_pt.fc2.weight[:, rank * partition_dim:(rank + 1) * partition_dim]
model_pt.fc2.weight[:, rank * partition_dim : (rank + 1) * partition_dim]
)
if rank == 0:
model.fc2.bias.copy_(model_pt.fc2.bias)
@@ -76,39 +89,55 @@ def test_mlp_parallel(dim, activation, sequence_parallel, world_size, dtype):
partition_batch_dim = batch_size * seqlen // world_size
assert torch.allclose(
out,
out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else out_pt,
rtol=rtol, atol=atol
out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel
else out_pt,
rtol=rtol,
atol=atol,
)
out_pt.backward(g)
out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else g)
out.backward(
g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else g
)
parallel_state.destroy_model_parallel()
assert torch.allclose(
x.grad,
x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else x_pt.grad,
rtol=rtol, atol=atol
x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel
else x_pt.grad,
rtol=rtol,
atol=atol,
)
assert torch.allclose(
model.fc1.weight.grad,
rearrange(rearrange(model_pt.fc1.weight.grad, '(two o) i -> two o i', two=2)[:, rank * partition_dim:(rank + 1) * partition_dim],
'two o i -> (two o) i'),
rtol=rtol, atol=atol
rearrange(
rearrange(model_pt.fc1.weight.grad, "(two o) i -> two o i", two=2)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"two o i -> (two o) i",
),
rtol=rtol,
atol=atol,
)
assert torch.allclose(
model.fc1.bias.grad,
rearrange(rearrange(model_pt.fc1.bias.grad, '(two o) -> two o', two=2)[:, rank * partition_dim:(rank + 1) * partition_dim],
'two o -> (two o)'),
rtol=rtol, atol=atol
rearrange(
rearrange(model_pt.fc1.bias.grad, "(two o) -> two o", two=2)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"two o -> (two o)",
),
rtol=rtol,
atol=atol,
)
assert torch.allclose(
model.fc2.weight.grad,
model_pt.fc2.weight.grad[:, rank * partition_dim:(rank + 1) * partition_dim],
rtol=rtol, atol=atol
model_pt.fc2.weight.grad[:, rank * partition_dim : (rank + 1) * partition_dim],
rtol=rtol,
atol=atol,
)
if rank == 0:
assert torch.allclose(model.fc2.bias.grad, model_pt.fc2.bias.grad, rtol=rtol, atol=atol)
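# Standalone sketch of the rearrange-based sharding pattern used by both tests above:
# it slices a fused projection weight down to one tensor-parallel rank's shard. All
# sizes, names, and the rank below are toy values for illustration only; nothing here
# needs a GPU, apex, or torch.distributed.
import torch
from einops import rearrange

embed_dim, world_size, rank = 8, 2, 1  # hypothetical toy configuration
partition_dim = embed_dim // world_size

# A fused QKV weight of shape (3 * embed_dim, embed_dim), like MHA.Wqkv.weight.
wqkv = torch.arange(3 * embed_dim * embed_dim, dtype=torch.float32).reshape(
    3 * embed_dim, embed_dim
)

# Unfold the output dim into (three, embed_dim), take this rank's slice of each of
# q, k and v, then fold back so the shard has the layout the parallel module expects.
shard = rearrange(
    rearrange(wqkv, "(three o) i -> three o i", three=3)[
        :, rank * partition_dim : (rank + 1) * partition_dim
    ],
    "three o i -> (three o) i",
)
assert shard.shape == (3 * partition_dim, embed_dim)

# The same slice without einops: split q, k, v, slice each, and re-concatenate.
q, k, v = wqkv.chunk(3, dim=0)
expected = torch.cat(
    [t[rank * partition_dim : (rank + 1) * partition_dim] for t in (q, k, v)], dim=0
)
assert torch.equal(shard, expected)
# The GatedMlp fc1 copy above is the same pattern with two=2 (gate and value halves)
# instead of three=3, and the biases follow the analogous 1-D version.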