Commit 0e8c46ae authored by Tri Dao

Run isort and black on test files

parent 7fcd3e6a
@@ -2,26 +2,24 @@
import math
import pytest
import torch
import torch.nn.functional as F
import pytest
from einops import rearrange
from flash_attn.layers.rotary import RotaryEmbedding, apply_rotary_emb_func, apply_rotary_emb_qkv_
from transformers.models.gpt_neox.modeling_gpt_neox import RotaryEmbedding as RotaryEmbeddingNeoX
from transformers.models.gpt_neox.modeling_gpt_neox import apply_rotary_pos_emb as apply_rotary_pos_emb_neox
from transformers.models.gptj.modeling_gptj import fixed_pos_embedding
from transformers.models.gpt_neox.modeling_gpt_neox import (
apply_rotary_pos_emb as apply_rotary_pos_emb_neox,
)
from transformers.models.gptj.modeling_gptj import apply_rotary_pos_emb as apply_rotary_pos_emb_gptj
from flash_attn.layers.rotary import apply_rotary_emb_func, apply_rotary_emb_qkv_
from flash_attn.layers.rotary import RotaryEmbedding
from transformers.models.gptj.modeling_gptj import fixed_pos_embedding
# NeoX-style rotary embedding
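# Note: NeoX-style rotary acts on the first rotary_dim channels of q and k, split into two
# halves and rotated as x_rot = x * cos + rotate_half(x) * sin, where rotate_half swaps and
# negates the halves; the remaining headdim - rotary_dim channels pass through unchanged,
# which is what the equality checks on out[..., rotary_dim:] below verify.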
@pytest.mark.parametrize('seqlen_offset', [0, 711])
@pytest.mark.parametrize('rotary_emb_fraction', [0.5, 1.0])
@pytest.mark.parametrize("seqlen_offset", [0, 711])
@pytest.mark.parametrize("rotary_emb_fraction", [0.5, 1.0])
def test_rotary(rotary_emb_fraction, seqlen_offset):
device = 'cuda'
device = "cuda"
dtype = torch.float16
rtol, atol = (1e-3, 5e-3)
# set seed
@@ -32,49 +30,70 @@ def test_rotary(rotary_emb_fraction, seqlen_offset):
nheads = 16
headdim = 128
rotary_dim = int(headdim * rotary_emb_fraction)
qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
requires_grad=True)
qkv = torch.randn(
batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype, requires_grad=True
)
qkv_og = qkv.clone().detach() # Our implementation modifies qkv inplace
rotary = RotaryEmbedding(rotary_dim, device=device)
rotary_neox = RotaryEmbeddingNeoX(rotary_dim, seqlen_total, device=device)
# Doesn't matter what tensor we pass in, rotary_neox only uses the device of the tensor
cos_neox, sin_neox = rotary_neox(qkv, seq_len=seqlen_total)
cos_neox, sin_neox = cos_neox.to(dtype=dtype), sin_neox.to(dtype=dtype)
q_pt = rearrange(qkv[:, :, 0, :, :rotary_dim],
'b s h d -> b h s d').detach().clone().requires_grad_(True)
k_pt = rearrange(qkv[:, :, 1, :, :rotary_dim],
'b s h d -> b h s d').detach().clone().requires_grad_(True)
q_pt = (
rearrange(qkv[:, :, 0, :, :rotary_dim], "b s h d -> b h s d")
.detach()
.clone()
.requires_grad_(True)
)
k_pt = (
rearrange(qkv[:, :, 1, :, :rotary_dim], "b s h d -> b h s d")
.detach()
.clone()
.requires_grad_(True)
)
q_neox, k_neox = apply_rotary_pos_emb_neox(q_pt, k_pt, cos_neox, sin_neox, offset=seqlen_offset)
out = rotary(qkv, seqlen_offset=seqlen_offset)
assert torch.allclose(rotary._cos_cached, cos_neox[..., :rotary_dim // 2].to(dtype=dtype),
rtol=rtol, atol=atol)
assert torch.allclose(rotary._sin_cached, sin_neox[..., :rotary_dim // 2].to(dtype=dtype),
rtol=rtol, atol=atol)
assert torch.allclose(rearrange(q_neox, 'b h s d -> b s h d'), out[:, :, 0, :, :rotary_dim],
rtol=rtol, atol=atol)
assert torch.allclose(rearrange(k_neox, 'b h s d -> b s h d'), out[:, :, 1, :, :rotary_dim],
rtol=rtol, atol=atol)
assert torch.allclose(
rotary._cos_cached, cos_neox[..., : rotary_dim // 2].to(dtype=dtype), rtol=rtol, atol=atol
)
assert torch.allclose(
rotary._sin_cached, sin_neox[..., : rotary_dim // 2].to(dtype=dtype), rtol=rtol, atol=atol
)
assert torch.allclose(
rearrange(q_neox, "b h s d -> b s h d"), out[:, :, 0, :, :rotary_dim], rtol=rtol, atol=atol
)
assert torch.allclose(
rearrange(k_neox, "b h s d -> b s h d"), out[:, :, 1, :, :rotary_dim], rtol=rtol, atol=atol
)
assert torch.equal(out[:, :, 0:2, :, rotary_dim:], qkv_og[:, :, 0:2, :, rotary_dim:])
assert torch.equal(out[:, :, 2], qkv_og[:, :, 2])
g = torch.randn_like(out)
g_og = g.clone().detach() # Our implementation modifies g inplace
out.backward(g)
q_neox.backward(rearrange(g_og[:, :, 0, :, :rotary_dim], 'b s h d -> b h s d'))
k_neox.backward(rearrange(g_og[:, :, 1, :, :rotary_dim], 'b s h d -> b h s d'))
assert torch.allclose(rearrange(q_pt.grad, 'b h s d -> b s h d'),
qkv.grad[:, :, 0, :, :rotary_dim], rtol=rtol, atol=atol)
assert torch.allclose(rearrange(k_pt.grad, 'b h s d -> b s h d'),
qkv.grad[:, :, 1, :, :rotary_dim], rtol=rtol, atol=atol)
q_neox.backward(rearrange(g_og[:, :, 0, :, :rotary_dim], "b s h d -> b h s d"))
k_neox.backward(rearrange(g_og[:, :, 1, :, :rotary_dim], "b s h d -> b h s d"))
assert torch.allclose(
rearrange(q_pt.grad, "b h s d -> b s h d"),
qkv.grad[:, :, 0, :, :rotary_dim],
rtol=rtol,
atol=atol,
)
assert torch.allclose(
rearrange(k_pt.grad, "b h s d -> b s h d"),
qkv.grad[:, :, 1, :, :rotary_dim],
rtol=rtol,
atol=atol,
)
assert torch.equal(qkv.grad[:, :, 0:2, :, rotary_dim:], g_og[:, :, 0:2, :, rotary_dim:])
assert torch.equal(qkv.grad[:, :, 2], g_og[:, :, 2])
# GPT-J-style rotary embedding
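# Note: GPT-J rotates interleaved (even, odd) channel pairs rather than the two half-dim
# blocks used by NeoX; RotaryEmbedding(interleaved=True) below selects that layout.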
@pytest.mark.parametrize('seqlen_offset', [0, 711])
@pytest.mark.parametrize('rotary_emb_fraction', [0.5, 1.0])
@pytest.mark.parametrize("seqlen_offset", [0, 711])
@pytest.mark.parametrize("rotary_emb_fraction", [0.5, 1.0])
def test_rotary_interleaved(rotary_emb_fraction, seqlen_offset):
device = 'cuda'
device = "cuda"
dtype = torch.float16
rtol, atol = (1e-3, 5e-3)
# set seed
@@ -85,8 +104,9 @@ def test_rotary_interleaved(rotary_emb_fraction, seqlen_offset):
nheads = 16
headdim = 128
rotary_dim = int(headdim * rotary_emb_fraction)
qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
requires_grad=True)
qkv = torch.randn(
batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype, requires_grad=True
)
qkv_og = qkv.clone().detach() # Our implementation modifies qkv inplace
rotary = RotaryEmbedding(rotary_dim, interleaved=True, device=device)
sincos_gptj = fixed_pos_embedding(qkv[..., :rotary_dim], seq_dim=1, seq_len=seqlen_total)
......
import math
import pytest
import torch
import torch.nn.functional as F
import pytest
from einops import rearrange
from flash_attn.losses.cross_entropy import CrossEntropyLossApex
is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
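# bfloat16 is only included in the dtype list on SM 8.x (Ampere) or newer GPUs, where bf16 is supported.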
@pytest.mark.parametrize('dtype', [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else []))
@pytest.mark.parametrize(
"dtype", [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else [])
)
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('inplace_backward', [False, True])
@pytest.mark.parametrize("inplace_backward", [False, True])
# @pytest.mark.parametrize('inplace_backward', [False])
@pytest.mark.parametrize('smoothing', [0.0, 0.9])
@pytest.mark.parametrize('vocab_size', [50257])
@pytest.mark.parametrize("smoothing", [0.0, 0.9])
@pytest.mark.parametrize("vocab_size", [50257])
def test_cross_entropy_loss_apex(vocab_size, smoothing, inplace_backward, dtype):
device = 'cuda'
device = "cuda"
rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen = 128
x_pt = torch.randn(batch_size * seqlen, vocab_size, device=device, dtype=dtype, requires_grad=True)
x_pt = torch.randn(
batch_size * seqlen, vocab_size, device=device, dtype=dtype, requires_grad=True
)
x = x_pt.detach().clone().requires_grad_()
y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device)
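# -100 is PyTorch's default ignore_index for cross-entropy, so these 10 random positions
# are excluded from the loss.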
y[torch.randperm(batch_size * seqlen)[:10]] = -100
......
@@ -3,35 +3,37 @@
import math
import pytest
import torch
import torch.nn.functional as F
import pytest
from apex.transformer import parallel_state
from apex.transformer import tensor_parallel
from apex.transformer import parallel_state, tensor_parallel
from flash_attn.losses.cross_entropy import CrossEntropyLoss
is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
@pytest.mark.parametrize('dtype', [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else []))
@pytest.mark.parametrize(
"dtype", [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else [])
)
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('inplace_backward', [False, True])
@pytest.mark.parametrize("inplace_backward", [False, True])
# @pytest.mark.parametrize('inplace_backward', [False])
@pytest.mark.parametrize('smoothing', [0.0, 0.9])
@pytest.mark.parametrize("smoothing", [0.0, 0.9])
# @pytest.mark.parametrize('smoothing', [0.9])
@pytest.mark.parametrize('vocab_size', [50264])
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
@pytest.mark.parametrize("vocab_size", [50264])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
def test_cross_entropy_loss_parallel(vocab_size, world_size, smoothing, inplace_backward, dtype):
assert vocab_size % world_size == 0
rtol, atol = ((1e-5, 1e-6) if dtype == torch.float32
else ((1e-3, 1e-4) if dtype == torch.float16 else (1e-2, 3e-3)))
rtol, atol = (
(1e-5, 1e-6)
if dtype == torch.float32
else ((1e-3, 1e-4) if dtype == torch.float16 else (1e-2, 3e-3))
)
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend='nccl', init_method='env://')
torch.distributed.init_process_group(backend="nccl", init_method="env://")
partition_vocab_size = vocab_size // world_size
device = f'cuda:{torch.distributed.get_rank()}'
device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
@@ -39,15 +41,24 @@ def test_cross_entropy_loss_parallel(vocab_size, world_size, smoothing, inplace_
torch.random.manual_seed(0)
batch_size = 8
seqlen = 128
x_pt = (torch.randn(batch_size * seqlen, vocab_size, device=device,
dtype=dtype) * 10).requires_grad_()
x = tensor_parallel.scatter_to_tensor_model_parallel_region(x_pt).detach().clone().requires_grad_()
x_pt = (
torch.randn(batch_size * seqlen, vocab_size, device=device, dtype=dtype) * 10
).requires_grad_()
x = (
tensor_parallel.scatter_to_tensor_model_parallel_region(x_pt)
.detach()
.clone()
.requires_grad_()
)
y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device)
y[torch.randperm(batch_size * seqlen)[:10]] = -100
model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing, reduction='none')
model = CrossEntropyLoss(label_smoothing=smoothing, reduction='none',
inplace_backward=inplace_backward,
process_group=parallel_state.get_tensor_model_parallel_group())
model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing, reduction="none")
model = CrossEntropyLoss(
label_smoothing=smoothing,
reduction="none",
inplace_backward=inplace_backward,
process_group=parallel_state.get_tensor_model_parallel_group(),
)
out = model(x, y)
out_pt = model_pt(x_pt.float(), y)
assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)
@@ -55,6 +66,11 @@ def test_cross_entropy_loss_parallel(vocab_size, world_size, smoothing, inplace_
g = torch.randn_like(out)
out_pt.backward(g)
out.backward(g)
assert torch.allclose(x.grad, x_pt.grad[:, (rank * partition_vocab_size):(rank + 1) * partition_vocab_size], rtol=rtol, atol=atol)
assert torch.allclose(
x.grad,
x_pt.grad[:, (rank * partition_vocab_size) : (rank + 1) * partition_vocab_size],
rtol=rtol,
atol=atol,
)
parallel_state.destroy_model_parallel()
import re
from collections import OrderedDict
import pytest
import torch
import torch.nn.functional as F
import pytest
from einops import rearrange
from flash_attn.models.bert import BertForPreTraining, BertModel, remap_state_dict
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertModel as BertModelHF
from transformers.models.bert.modeling_bert import BertForPreTraining as BertForPreTrainingHF
from flash_attn.models.bert import BertModel, BertForPreTraining
from flash_attn.models.bert import remap_state_dict
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers.models.bert.modeling_bert import BertModel as BertModelHF
@pytest.mark.parametrize('model_name', ["bert-base-uncased", "bert-large-uncased"])
@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_state_dict(model_name):
config = BertConfig.from_pretrained(model_name)
@@ -30,12 +26,15 @@ def test_bert_state_dict(model_name):
def get_hf_models(model_name, config, dtype):
pretrained_state_dict = state_dict_from_pretrained(model_name)
def key_mapping_ln_gamma_beta(key):
key = re.sub(r'LayerNorm.gamma$', 'LayerNorm.weight', key)
key = re.sub(r'LayerNorm.beta$', 'LayerNorm.bias', key)
key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key)
return key
pretrained_state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v)
for k, v in pretrained_state_dict.items())
pretrained_state_dict = OrderedDict(
(key_mapping_ln_gamma_beta(k), v) for k, v in pretrained_state_dict.items()
)
model_hf = BertForPreTrainingHF(config)
# Missing key(s) in state_dict: "bert.embeddings.position_ids", "cls.predictions.decoder.bias"
# position_ids is a buffer, and predictions.decoder.bias is tied to predictions.bias.
@@ -44,7 +43,7 @@ def get_hf_models(model_name, config, dtype):
return model_hf
@pytest.mark.parametrize('model_name', ["bert-base-uncased", "bert-large-uncased"])
@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_non_optimized(model_name):
"""Check that our implementation of BERT (without any optimizations enabled) matches the
@@ -67,10 +66,11 @@ def test_bert_non_optimized(model_name):
torch.manual_seed(0)
batch_size = 4
max_seqlen = 512
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda')
attention_mask = torch.arange(max_seqlen, device='cuda')[None, :] < seqlens[:, None]
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
device='cuda')
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
attention_mask = torch.arange(max_seqlen, device="cuda")[None, :] < seqlens[:, None]
input_ids = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
)
out = model.bert(input_ids, attention_mask=attention_mask)
sequence_output, pooled_output = out.last_hidden_state, out.pooler_output
out_hf = model_hf.bert(input_ids, attention_mask=attention_mask)
@@ -78,15 +78,19 @@ def test_bert_non_optimized(model_name):
out_ref = model_ref.bert(input_ids, attention_mask=attention_mask)
sequence_output_ref, pooled_output_ref = out_ref.last_hidden_state, out_ref.pooler_output
print(f'Output max diff: {(sequence_output - sequence_output_ref).abs().max().item()}')
print(f'Output mean diff: {(sequence_output - sequence_output_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(sequence_output_hf - sequence_output_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(sequence_output_hf - sequence_output_ref).abs().mean().item()}')
assert (sequence_output - sequence_output_ref).abs().max().item() < 3 * (sequence_output_hf - sequence_output_ref).abs().max().item()
assert (pooled_output - pooled_output_ref).abs().max().item() < 3 * (pooled_output_hf - pooled_output_ref).abs().max().item()
print(f"Output max diff: {(sequence_output - sequence_output_ref).abs().max().item()}")
print(f"Output mean diff: {(sequence_output - sequence_output_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(sequence_output_hf - sequence_output_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(sequence_output_hf - sequence_output_ref).abs().mean().item()}")
assert (sequence_output - sequence_output_ref).abs().max().item() < 3 * (
sequence_output_hf - sequence_output_ref
).abs().max().item()
assert (pooled_output - pooled_output_ref).abs().max().item() < 3 * (
pooled_output_hf - pooled_output_ref
).abs().max().item()
@pytest.mark.parametrize('model_name', ["bert-base-uncased", "bert-large-uncased"])
@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_optimized(model_name):
"""Check that our implementation of BERT (with all optimizations enabled) matches the
@@ -117,10 +121,11 @@ def test_bert_optimized(model_name):
torch.manual_seed(0)
batch_size = 4
max_seqlen = 512
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda')
attention_mask = torch.arange(max_seqlen, device='cuda')[None, :] < seqlens[:, None]
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
device='cuda')
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
attention_mask = torch.arange(max_seqlen, device="cuda")[None, :] < seqlens[:, None]
input_ids = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
)
out = model.bert(input_ids, attention_mask=attention_mask)
sequence_output, pooled_output = out.last_hidden_state, out.pooler_output
out_hf = model_hf.bert(input_ids, attention_mask=attention_mask)
@@ -131,12 +136,24 @@ def test_bert_optimized(model_name):
sequence_output_ref, pooled_output_ref = out_ref.last_hidden_state, out_ref.pooler_output
sequence_output_ref[~attention_mask, :] = 0.0
print(f'BertModel output max diff: {(sequence_output - sequence_output_ref).abs().max().item()}')
print(f'BertModel output mean diff: {(sequence_output - sequence_output_ref).abs().mean().item()}')
print(f'HF fp16 BertModel max diff: {(sequence_output_hf - sequence_output_ref).abs().max().item()}')
print(f'HF fp16 BertModel mean diff: {(sequence_output_hf - sequence_output_ref).abs().mean().item()}')
assert (sequence_output - sequence_output_ref).abs().max().item() < 4 * (sequence_output_hf - sequence_output_ref).abs().max().item()
assert (pooled_output - pooled_output_ref).abs().max().item() < 4 * (pooled_output_hf - pooled_output_ref).abs().max().item()
print(
f"BertModel output max diff: {(sequence_output - sequence_output_ref).abs().max().item()}"
)
print(
f"BertModel output mean diff: {(sequence_output - sequence_output_ref).abs().mean().item()}"
)
print(
f"HF fp16 BertModel max diff: {(sequence_output_hf - sequence_output_ref).abs().max().item()}"
)
print(
f"HF fp16 BertModel mean diff: {(sequence_output_hf - sequence_output_ref).abs().mean().item()}"
)
assert (sequence_output - sequence_output_ref).abs().max().item() < 4 * (
sequence_output_hf - sequence_output_ref
).abs().max().item()
assert (pooled_output - pooled_output_ref).abs().max().item() < 4 * (
pooled_output_hf - pooled_output_ref
).abs().max().item()
out = model(input_ids, attention_mask=attention_mask)
prediction_scores, seq_relationship_scores = out.prediction_logits, out.seq_relationship_logits
@@ -144,25 +161,43 @@ def test_bert_optimized(model_name):
prediction_scores = prediction_scores.clone()
prediction_scores[~attention_mask, :] = 0.0
out_hf = model_hf(input_ids, attention_mask=attention_mask)
prediction_scores_hf, seq_relationship_scores_hf = out_hf.prediction_logits, out_hf.seq_relationship_logits
prediction_scores_hf, seq_relationship_scores_hf = (
out_hf.prediction_logits,
out_hf.seq_relationship_logits,
)
prediction_scores_hf[~attention_mask, :] = 0.0
out_ref = model_ref(input_ids, attention_mask=attention_mask)
prediction_scores_ref, seq_relationship_scores_ref = out_ref.prediction_logits, out_ref.seq_relationship_logits
prediction_scores_ref, seq_relationship_scores_ref = (
out_ref.prediction_logits,
out_ref.seq_relationship_logits,
)
prediction_scores_ref[~attention_mask, :] = 0.0
print(f'prediction_scores max diff: {(prediction_scores - prediction_scores_ref).abs().max().item()}')
print(f'prediction_scores mean diff: {(prediction_scores - prediction_scores_ref).abs().mean().item()}')
print(f'HF fp16 prediction_scores max diff: {(prediction_scores_hf - prediction_scores_ref).abs().max().item()}')
print(f'HF fp16 prediction_scores mean diff: {(prediction_scores_hf - prediction_scores_ref).abs().mean().item()}')
assert (prediction_scores - prediction_scores_ref).abs().max().item() < 2 * (prediction_scores_hf - prediction_scores_ref).abs().max().item()
assert (seq_relationship_scores - seq_relationship_scores_ref).abs().max().item() < 2 * (seq_relationship_scores_hf - seq_relationship_scores_ref).abs().max().item()
print(
f"prediction_scores max diff: {(prediction_scores - prediction_scores_ref).abs().max().item()}"
)
print(
f"prediction_scores mean diff: {(prediction_scores - prediction_scores_ref).abs().mean().item()}"
)
print(
f"HF fp16 prediction_scoresff: {(prediction_scores_hf - prediction_scores_ref).abs().max().item()}"
)
print(
f"HF fp16 prediction_scoresiff: {(prediction_scores_hf - prediction_scores_ref).abs().mean().item()}"
)
assert (prediction_scores - prediction_scores_ref).abs().max().item() < 2 * (
prediction_scores_hf - prediction_scores_ref
).abs().max().item()
assert (seq_relationship_scores - seq_relationship_scores_ref).abs().max().item() < 2 * (
seq_relationship_scores_hf - seq_relationship_scores_ref
).abs().max().item()
@pytest.mark.parametrize('last_layer_subset', [False, True])
@pytest.mark.parametrize("last_layer_subset", [False, True])
# @pytest.mark.parametrize('last_layer_subset', [True])
@pytest.mark.parametrize('has_key_padding_mask', [True, False])
@pytest.mark.parametrize("has_key_padding_mask", [True, False])
# @pytest.mark.parametrize('has_key_padding_mask', [True])
@pytest.mark.parametrize('model_name', ["bert-base-uncased", "bert-large-uncased"])
@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_dense_seq_output(model_name, has_key_padding_mask, last_layer_subset):
"""Check that our implementation of BERT (with all optimizations enabled) matches the
@@ -196,40 +231,70 @@ def test_bert_dense_seq_output(model_name, has_key_padding_mask, last_layer_subs
torch.manual_seed(0)
batch_size = 4
max_seqlen = 512
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda')
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
if has_key_padding_mask:
attention_mask = torch.arange(max_seqlen, device='cuda')[None, :] < seqlens[:, None]
attention_mask = torch.arange(max_seqlen, device="cuda")[None, :] < seqlens[:, None]
else:
attention_mask = None
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
device='cuda')
labels = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
device='cuda')
input_ids = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
)
labels = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
)
if attention_mask is not None:
labels[~attention_mask] = 0
labels[(torch.rand(batch_size, max_seqlen, device='cuda') > 0.15)] = 0
labels[(torch.rand(batch_size, max_seqlen, device="cuda") > 0.15)] = 0
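# With dense_seq_output, the flash_attn model computes MLM logits only at the masked
# positions (labels > 0), so the HF/reference logits are gathered with the same mask
# before comparison below.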
masked_tokens_mask = labels.flatten() > 0
next_sequence_label = torch.randint(0, 2, (batch_size,), device='cuda')
next_sequence_label = torch.randint(0, 2, (batch_size,), device="cuda")
out = model(
input_ids, attention_mask=attention_mask,
labels=labels, next_sentence_label=next_sequence_label
input_ids,
attention_mask=attention_mask,
labels=labels,
next_sentence_label=next_sequence_label,
)
prediction_scores, seq_relationship_scores = out.prediction_logits, out.seq_relationship_logits
out_hf = model_hf(input_ids, attention_mask=attention_mask,
labels=labels, next_sentence_label=next_sequence_label)
prediction_scores_hf, seq_relationship_scores_hf = out_hf.prediction_logits, out_hf.seq_relationship_logits
prediction_scores_hf = rearrange(prediction_scores_hf, 'b s d -> (b s) d')[masked_tokens_mask]
out_ref = model_ref(input_ids, attention_mask=attention_mask,
labels=labels, next_sentence_label=next_sequence_label)
prediction_scores_ref, seq_relationship_scores_ref = out_ref.prediction_logits, out_ref.seq_relationship_logits
prediction_scores_ref = rearrange(prediction_scores_ref, 'b s d -> (b s) d')[masked_tokens_mask]
print(f'prediction_scores max diff: {(prediction_scores - prediction_scores_ref).abs().max().item()}')
print(f'prediction_scores mean diff: {(prediction_scores - prediction_scores_ref).abs().mean().item()}')
print(f'HF fp16 prediction_scores max diff: {(prediction_scores_hf - prediction_scores_ref).abs().max().item()}')
print(f'HF fp16 prediction_scores mean diff: {(prediction_scores_hf - prediction_scores_ref).abs().mean().item()}')
assert (prediction_scores - prediction_scores_ref).abs().max().item() < 2 * (prediction_scores_hf - prediction_scores_ref).abs().max().item()
assert (seq_relationship_scores - seq_relationship_scores_ref).abs().max().item() < 2 * (seq_relationship_scores_hf - seq_relationship_scores_ref).abs().max().item()
out_hf = model_hf(
input_ids,
attention_mask=attention_mask,
labels=labels,
next_sentence_label=next_sequence_label,
)
prediction_scores_hf, seq_relationship_scores_hf = (
out_hf.prediction_logits,
out_hf.seq_relationship_logits,
)
prediction_scores_hf = rearrange(prediction_scores_hf, "b s d -> (b s) d")[masked_tokens_mask]
out_ref = model_ref(
input_ids,
attention_mask=attention_mask,
labels=labels,
next_sentence_label=next_sequence_label,
)
prediction_scores_ref, seq_relationship_scores_ref = (
out_ref.prediction_logits,
out_ref.seq_relationship_logits,
)
prediction_scores_ref = rearrange(prediction_scores_ref, "b s d -> (b s) d")[masked_tokens_mask]
print(
f"prediction_scores max diff: {(prediction_scores - prediction_scores_ref).abs().max().item()}"
)
print(
f"prediction_scores mean diff: {(prediction_scores - prediction_scores_ref).abs().mean().item()}"
)
print(
f"HF fp16 prediction_scoresff: {(prediction_scores_hf - prediction_scores_ref).abs().max().item()}"
)
print(
f"HF fp16 prediction_scoresiff: {(prediction_scores_hf - prediction_scores_ref).abs().mean().item()}"
)
assert (prediction_scores - prediction_scores_ref).abs().max().item() < 2 * (
prediction_scores_hf - prediction_scores_ref
).abs().max().item()
assert (seq_relationship_scores - seq_relationship_scores_ref).abs().max().item() < 2 * (
seq_relationship_scores_hf - seq_relationship_scores_ref
).abs().max().item()
# The loss calculation from HF is wrong: it doesn't ignore the labels that are 0.
# assert (out.loss - out_ref.loss).abs().max().item() < 2 * (out_hf.loss - out_ref.loss).abs().max().item()
@@ -3,44 +3,46 @@
import os
import time
from pathlib import Path
current_dir = Path(__file__).parent.absolute()
import torch
import pytest
import torch
from einops import rearrange
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
from flash_attn.models.falcon import falcon_config_to_gpt2_config, remap_state_dict_hf_falcon
from flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp, shard_state_dict_tp
from flash_attn.models.falcon import remap_state_dict_hf_falcon, falcon_config_to_gpt2_config
from flash_attn.utils.distributed import all_gather_raw
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.generation import update_graph_cache
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
@pytest.mark.parametrize('model_name', ["tiiuae/falcon-7b", "tiiuae/falcon-40b"])
@pytest.mark.parametrize("model_name", ["tiiuae/falcon-7b", "tiiuae/falcon-40b"])
def test_falcon_state_dict(model_name):
config = falcon_config_to_gpt2_config(AutoConfig.from_pretrained(model_name,
trust_remote_code=True))
pretrained_state_dict = remap_state_dict_hf_falcon(state_dict_from_pretrained(model_name), config)
model = GPTLMHeadModel(config, device='meta') # Without device='meta' init is very slow
config = falcon_config_to_gpt2_config(
AutoConfig.from_pretrained(model_name, trust_remote_code=True)
)
pretrained_state_dict = remap_state_dict_hf_falcon(
state_dict_from_pretrained(model_name), config
)
model = GPTLMHeadModel(config, device="meta") # Without device='meta' init is very slow
state_dict = model.state_dict()
assert state_dict.keys() == pretrained_state_dict.keys()
for k in state_dict.keys():
assert state_dict[k].shape == pretrained_state_dict[k].shape
@pytest.mark.parametrize('model_name', ["tiiuae/falcon-7b"])
@pytest.mark.parametrize("model_name", ["tiiuae/falcon-7b"])
def test_falcon_optimized(model_name):
"""Check that our implementation (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
dtype = torch.float16
device = 'cuda'
config = falcon_config_to_gpt2_config(AutoConfig.from_pretrained(model_name,
trust_remote_code=True))
device = "cuda"
config = falcon_config_to_gpt2_config(
AutoConfig.from_pretrained(model_name, trust_remote_code=True)
)
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = False # We don't have fused MLP for "gelu" activation
@@ -53,8 +55,9 @@ def test_falcon_optimized(model_name):
torch.manual_seed(0)
batch_size = 2
max_seqlen = 256
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
device=device)
input_ids = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
)
with torch.no_grad():
out = model.transformer(input_ids)
logits = model(input_ids).logits
@@ -78,30 +81,33 @@ def test_falcon_optimized(model_name):
logits_hf = model_hf(input_ids).logits
del model_hf
print(f'Output max diff: {(out - out_ref).abs().max().item()}')
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item()
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
assert (logits - logits_ref).abs().max().item() < 3 * (
logits_hf - logits_ref
).abs().max().item()
# torchrun --no_python --nproc_per_node=4 pytest -q -s tests/models/test_falcon.py -k "falcon_parallel_forward"
# We want to run this on a machine with 4 x A100 80GB or 8 x A100 40GB so we have enough
# memory to run the model in fp32.
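# Each rank loads only its shard of the pretrained weights (shard_state_dict_tp below),
# and the per-rank outputs are all-gathered before comparing against the full HF model.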
@pytest.mark.parametrize('world_size', [4])
@pytest.mark.parametrize('model_name', ["tiiuae/falcon-40b"])
@pytest.mark.parametrize("world_size", [4])
@pytest.mark.parametrize("model_name", ["tiiuae/falcon-40b"])
def test_falcon_parallel_forward(model_name, world_size):
from apex.transformer import parallel_state
dtype = torch.float16
config = falcon_config_to_gpt2_config(AutoConfig.from_pretrained(model_name,
trust_remote_code=True))
config = falcon_config_to_gpt2_config(
AutoConfig.from_pretrained(model_name, trust_remote_code=True)
)
config.use_flash_attn = False
config.fused_bias_fc = True
config.fused_mlp = False # We don't have fused MLP for "gelu" activation
@@ -109,14 +115,16 @@ def test_falcon_parallel_forward(model_name, world_size):
config.residual_in_fp32 = True
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend='nccl', init_method='env://')
device = f'cuda:{torch.distributed.get_rank()}'
torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
process_group = parallel_state.get_tensor_model_parallel_group()
pretrained_state_dict = remap_state_dict_hf_falcon(state_dict_from_pretrained(model_name), config)
pretrained_state_dict = remap_state_dict_hf_falcon(
state_dict_from_pretrained(model_name), config
)
model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
@@ -126,8 +134,9 @@ def test_falcon_parallel_forward(model_name, world_size):
batch_size = 2
max_seqlen = 256
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
device=device)
input_ids = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
)
with torch.no_grad():
out = model.transformer(input_ids)
out, _ = all_gather_raw(out, process_group=process_group)
@@ -135,7 +144,7 @@ def test_falcon_parallel_forward(model_name, world_size):
logits = model(input_ids).logits
logits = rearrange(logits, "(b s) d -> b s d", b=batch_size)
logits, _ = all_gather_raw(logits, process_group)
logits = rearrange(logits, '(n b) ... d -> b ... (n d)', b=batch_size)
logits = rearrange(logits, "(n b) ... d -> b ... (n d)", b=batch_size)
del model
if rank == 0:
@@ -157,29 +166,32 @@ def test_falcon_parallel_forward(model_name, world_size):
logits_ref = model_ref(input_ids).logits.to(device=device)
del model_ref
print(f'Output max diff: {(out - out_ref).abs().max().item()}')
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()
print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
assert (logits - logits_ref).abs().max().item() < 2 * (logits_hf - logits_ref).abs().max().item()
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
assert (logits - logits_ref).abs().max().item() < 2 * (
logits_hf - logits_ref
).abs().max().item()
@pytest.mark.parametrize('model_name', ["tiiuae/falcon-7b"])
@pytest.mark.parametrize("model_name", ["tiiuae/falcon-7b"])
def test_falcon_generation(model_name):
"""Check that our implementation (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
dtype = torch.float16
device = 'cuda'
config = falcon_config_to_gpt2_config(AutoConfig.from_pretrained(model_name,
trust_remote_code=True))
device = "cuda"
config = falcon_config_to_gpt2_config(
AutoConfig.from_pretrained(model_name, trust_remote_code=True)
)
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = False # We don't have fused MLP for "gelu" activation
@@ -193,8 +205,9 @@ def test_falcon_generation(model_name):
batch_size = 1
seqlen = 100
max_length = 150
input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long,
device=device)
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
)
model_hf = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=dtype, device_map={"": device}, trust_remote_code=True
@@ -203,10 +216,11 @@ def test_falcon_generation(model_name):
print("HF fp16")
torch.cuda.synchronize()
start = time.time()
out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
return_dict_in_generate=True, output_scores=True)
out_hf = model_hf.generate(
input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
)
torch.cuda.synchronize()
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
del model_hf
model_ref = AutoModelForCausalLM.from_pretrained(
@@ -214,37 +228,49 @@ def test_falcon_generation(model_name):
)
model_ref.eval()
with torch.no_grad():
logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1]
logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]
del model_ref
model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
model.eval()
print('Without CUDA graph')
print("Without CUDA graph")
torch.cuda.synchronize()
start = time.time()
out = model.generate(input_ids=input_ids, max_length=max_length,
eos_token_id=eos_token_id, fused_ft_kernel=True,
return_dict_in_generate=True, output_scores=True, timing=True,
teacher_outputs=out_hf.sequences)
out = model.generate(
input_ids=input_ids,
max_length=max_length,
eos_token_id=eos_token_id,
fused_ft_kernel=True,
return_dict_in_generate=True,
output_scores=True,
timing=True,
teacher_outputs=out_hf.sequences,
)
torch.cuda.synchronize()
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
# Capture graph outside the timing loop
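# update_graph_cache records the single-token decoding step into CUDA graphs for this
# batch size and max_length; generate(..., cg=True) below then replays them instead of
# re-launching kernels at every step.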
batch_size, seqlen_og = input_ids.shape
model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
print('With CUDA graph')
print("With CUDA graph")
torch.cuda.synchronize()
start = time.time()
out_cg = model.generate(input_ids=input_ids, max_length=max_length,
fused_ft_kernel=True, cg=True,
return_dict_in_generate=True, output_scores=True, timing=True,
teacher_outputs=out_hf.sequences)
out_cg = model.generate(
input_ids=input_ids,
max_length=max_length,
fused_ft_kernel=True,
cg=True,
return_dict_in_generate=True,
output_scores=True,
timing=True,
teacher_outputs=out_hf.sequences,
)
torch.cuda.synchronize()
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
with torch.no_grad():
logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1):-1]
logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1]
logits_hf = torch.stack(out_hf.scores, dim=1)
logits = torch.stack(out.scores, dim=1)
logits_cg = torch.stack(out_cg.scores, dim=1)
@@ -254,18 +280,18 @@ def test_falcon_generation(model_name):
hf_error = (logits_hf - logits_ref).abs().max().item()
assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error
print(f'HF fp16 logits max diff: {hf_error}')
print(f'Logits max diff: {(logits - logits_ref).abs().max().item() }')
print(f"HF fp16 logits max diff: {hf_error}")
print(f"Logits max diff: {(logits - logits_ref).abs().max().item() }")
assert (logits - logits_ref).abs().max().item() < 2 * hf_error
print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }')
print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }")
assert torch.equal(logits_cg, logits)
# torchrun --no_python --nproc_per_node=4 pytest -q -s tests/models/test_falcon.py -k "falcon_parallel_generation"
# We want to run this on a machine with 4 x A100 80GB or 8 x A100 40GB so we have enough
# memory to run the model in fp32.
@pytest.mark.parametrize('world_size', [4])
@pytest.mark.parametrize('model_name', ["tiiuae/falcon-40b"])
@pytest.mark.parametrize("world_size", [4])
@pytest.mark.parametrize("model_name", ["tiiuae/falcon-40b"])
def test_falcon_parallel_generation(model_name, world_size):
"""Check that our implementation matches the HF implementation:
the scores in fp16 should be around the same as the HF scores in fp16, when compared to
@@ -274,8 +300,9 @@ def test_falcon_parallel_generation(model_name, world_size):
from apex.transformer import parallel_state
dtype = torch.float16
config = falcon_config_to_gpt2_config(AutoConfig.from_pretrained(model_name,
trust_remote_code=True))
config = falcon_config_to_gpt2_config(
AutoConfig.from_pretrained(model_name, trust_remote_code=True)
)
config.use_flash_attn = False
config.fused_bias_fc = True
config.fused_mlp = False # We don't have fused MLP for "gelu" activation
@@ -286,8 +313,8 @@ def test_falcon_parallel_generation(model_name, world_size):
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend='nccl', init_method='env://')
device = f'cuda:{torch.distributed.get_rank()}'
torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
@@ -297,36 +324,50 @@ def test_falcon_parallel_generation(model_name, world_size):
batch_size = 1
seqlen = 100
max_length = 150
input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long,
device=device)
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
)
# Need this, otherwise when we capture the graph the process for GPU 1 would run on both
# GPU0 and GPU1 and things would hang
torch.cuda.set_device(device)
pretrained_state_dict = remap_state_dict_hf_falcon(state_dict_from_pretrained(model_name), config)
pretrained_state_dict = remap_state_dict_hf_falcon(
state_dict_from_pretrained(model_name), config
)
model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
model.eval()
print('Without CUDA graph')
print("Without CUDA graph")
out = model.generate(
input_ids=input_ids, max_length=max_length, tensor_parallel=world_size,
vocab_size=config.vocab_size, fused_ft_kernel=True,
input_ids=input_ids,
max_length=max_length,
tensor_parallel=world_size,
vocab_size=config.vocab_size,
fused_ft_kernel=True,
# teacher_outputs=out_hf.sequences,
return_dict_in_generate=True, output_scores=True, timing=True
return_dict_in_generate=True,
output_scores=True,
timing=True,
)
# Capture graph outside the timing loop
batch_size, seqlen_og = input_ids.shape
model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
print('With CUDA graph')
print("With CUDA graph")
out_cg = model.generate(
input_ids=input_ids, max_length=max_length, tensor_parallel=world_size,
vocab_size=config.vocab_size, fused_ft_kernel=True, cg=True,
input_ids=input_ids,
max_length=max_length,
tensor_parallel=world_size,
vocab_size=config.vocab_size,
fused_ft_kernel=True,
cg=True,
# teacher_outputs=out_hf.sequences,
return_dict_in_generate=True, output_scores=True, timing=True
return_dict_in_generate=True,
output_scores=True,
timing=True,
)
del model
parallel_state.destroy_model_parallel()
@@ -341,11 +382,13 @@ def test_falcon_parallel_generation(model_name, world_size):
start = time.time()
with torch.inference_mode():
out_hf = model_hf.generate(
input_ids=input_ids, max_length=max_length, return_dict_in_generate=True,
output_scores=True
input_ids=input_ids,
max_length=max_length,
return_dict_in_generate=True,
output_scores=True,
)
torch.cuda.synchronize()
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
del model_hf
model_ref = AutoModelForCausalLM.from_pretrained(
@@ -353,7 +396,7 @@ def test_falcon_parallel_generation(model_name, world_size):
)
model_ref.eval()
with torch.inference_mode():
logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1]
logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]
del model_ref
logits_hf = torch.stack(out_hf.scores, dim=1)
@@ -361,8 +404,8 @@ def test_falcon_parallel_generation(model_name, world_size):
logits_cg = torch.stack(out_cg.scores, dim=1)
hf_error = (logits_hf - logits_ref).abs().max().item()
print(f'HF fp16 logits max diff: {hf_error}')
print(f'Logits max diff: {(logits - logits_ref).abs().max().item() }')
print(f"HF fp16 logits max diff: {hf_error}")
print(f"Logits max diff: {(logits - logits_ref).abs().max().item() }")
assert (logits - logits_ref).abs().max().item() < 2 * hf_error
print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }')
print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }")
assert torch.equal(logits_cg, logits)
import re
import torch
import pytest
import torch
from flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_gpt2
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gpt import remap_state_dict_hf_gpt2
from flash_attn.utils.pretrained import state_dict_from_pretrained
@pytest.mark.parametrize('model_name', ["gpt2", "gpt2-medium"])
@pytest.mark.parametrize("model_name", ["gpt2", "gpt2-medium"])
# @pytest.mark.parametrize('model_name', ["gpt2"])
def test_gpt2_state_dict(model_name):
config = GPT2Config.from_pretrained(model_name)
@@ -23,7 +20,7 @@ def test_gpt2_state_dict(model_name):
assert state_dict[k].shape == pretrained_state_dict[k].shape
@pytest.mark.parametrize('model_name', ["gpt2", "gpt2-medium"])
@pytest.mark.parametrize("model_name", ["gpt2", "gpt2-medium"])
# @pytest.mark.parametrize('model_name', ["gpt2"])
def test_gpt2_non_optimized(model_name):
"""Check that our implementation of GPT2 (without any optimizations enabled) matches the
@@ -46,31 +43,34 @@ def test_gpt2_non_optimized(model_name):
torch.manual_seed(0)
batch_size = 4
max_seqlen = 512
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda')
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
device='cuda')
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
input_ids = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
)
out = model.transformer(input_ids)
out_hf = model_hf.transformer(input_ids).last_hidden_state
out_ref = model_ref.transformer(input_ids).last_hidden_state
print(f'Output max diff: {(out - out_ref).abs().max().item()}')
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
logits = model(input_ids).logits
logits_hf = model_hf(input_ids).logits
logits_ref = model_ref(input_ids).logits
print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item()
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
assert (logits - logits_ref).abs().max().item() < 3 * (
logits_hf - logits_ref
).abs().max().item()
@pytest.mark.parametrize('model_name', ["gpt2", "gpt2-medium"])
@pytest.mark.parametrize("model_name", ["gpt2", "gpt2-medium"])
# @pytest.mark.parametrize('model_name', ["gpt2"])
def test_gpt2_optimized(model_name):
"""Check that our implementation of GPT2 (with all optimizations enabled) matches the
@@ -100,25 +100,28 @@ def test_gpt2_optimized(model_name):
torch.manual_seed(0)
batch_size = 4
max_seqlen = 512
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda')
input_ids = torch.randint(0, vocab_size_og, (batch_size, max_seqlen), dtype=torch.long,
device='cuda')
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
input_ids = torch.randint(
0, vocab_size_og, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
)
out = model.transformer(input_ids)
out_hf = model_hf.transformer(input_ids).last_hidden_state
out_ref = model_ref.transformer(input_ids).last_hidden_state
print(f'Output max diff: {(out - out_ref).abs().max().item()}')
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
logits = model(input_ids).logits[..., :vocab_size_og]
logits_hf = model_hf(input_ids).logits
logits_ref = model_ref(input_ids).logits
print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item()
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
assert (logits - logits_ref).abs().max().item() < 3 * (
logits_hf - logits_ref
).abs().max().item()
@@ -2,36 +2,32 @@ import os
import re
import time
import torch
import pytest
import torch
from einops import rearrange
from transformers import GPT2Config, GPT2Tokenizer, OPTConfig, AutoTokenizer
from flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_gpt2
from flash_attn.models.opt import opt_config_to_gpt2_config, remap_state_dict_hf_opt
from flash_attn.utils.generation import update_graph_cache
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import AutoTokenizer, GPT2Config, GPT2Tokenizer, OPTConfig
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF
from transformers.models.opt.modeling_opt import OPTForCausalLM
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gpt import remap_state_dict_hf_gpt2
from flash_attn.models.opt import remap_state_dict_hf_opt, opt_config_to_gpt2_config
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.generation import update_graph_cache
@pytest.mark.parametrize('fused_ft_kernel', [False, True])
@pytest.mark.parametrize("fused_ft_kernel", [False, True])
# @pytest.mark.parametrize('fused_ft_kernel', [True])
@pytest.mark.parametrize('optimized', [False, True])
@pytest.mark.parametrize("optimized", [False, True])
# @pytest.mark.parametrize('optimized', [False])
@pytest.mark.parametrize('rotary', [False, True])
@pytest.mark.parametrize("rotary", [False, True])
# @pytest.mark.parametrize('rotary', [False])
@pytest.mark.parametrize('model_name', ["gpt2"])
@pytest.mark.parametrize("model_name", ["gpt2"])
def test_greedy_decode_gpt2(model_name, rotary, optimized, fused_ft_kernel):
"""Check that our implementation of GPT2 generation matches the HF implementation:
the scores in fp16 should be around the same as the HF scores in fp16, when compared to
the HF scores in fp32.
"""
dtype = torch.float16
device = 'cuda'
device = "cuda"
rtol, atol = 3e-3, 3e-1
config = GPT2Config.from_pretrained(model_name)
if rotary:
@@ -47,21 +43,24 @@ def test_greedy_decode_gpt2(model_name, rotary, optimized, fused_ft_kernel):
# if rotary, we load the weight from HF but ignore the position embeddings.
# The model would be nonsense but it doesn't matter for the test.
model = GPTLMHeadModel.from_pretrained(model_name, config, strict=not rotary, device=device,
dtype=dtype)
model = GPTLMHeadModel.from_pretrained(
model_name, config, strict=not rotary, device=device, dtype=dtype
)
model.eval()
if not rotary:
model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device)
model_hf = GPT2LMHeadModelHF.from_pretrained(model_name,
torch_dtype=dtype).to(device=device)
model_hf = GPT2LMHeadModelHF.from_pretrained(model_name, torch_dtype=dtype).to(
device=device
)
model_ref.eval()
model_hf.eval()
torch.manual_seed(0)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
input_ids = tokenizer("Hello, my dog is cute and he",
return_tensors="pt").input_ids.to(device=device)
input_ids = tokenizer("Hello, my dog is cute and he", return_tensors="pt").input_ids.to(
device=device
)
max_length = 25
# input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
# max_length = input_ids.shape[1] + 40
@@ -74,61 +73,102 @@ def test_greedy_decode_gpt2(model_name, rotary, optimized, fused_ft_kernel):
scores.append(model(cur_input_ids).logits[:, -1])
sequences.append(scores[-1].argmax(dim=-1))
for _ in range(input_ids.shape[1] + 1, max_length):
cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], 'b -> b 1')], dim=-1)
cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], "b -> b 1")], dim=-1)
scores.append(model(cur_input_ids).logits[:, -1])
sequences.append(scores[-1].argmax(dim=-1))
sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
scores = tuple(scores)
out = model.generate(input_ids=input_ids, max_length=max_length,
fused_ft_kernel=fused_ft_kernel,
return_dict_in_generate=True, output_scores=True, timing=True)
out = model.generate(
input_ids=input_ids,
max_length=max_length,
fused_ft_kernel=fused_ft_kernel,
return_dict_in_generate=True,
output_scores=True,
timing=True,
)
print(out.sequences)
print(tokenizer.batch_decode(out.sequences.tolist()))
if fused_ft_kernel:
out_cg = model.generate(input_ids=input_ids, max_length=max_length,
fused_ft_kernel=fused_ft_kernel, cg=True,
return_dict_in_generate=True, output_scores=True, timing=True)
out_cg = model.generate(
input_ids=input_ids,
max_length=max_length,
fused_ft_kernel=fused_ft_kernel,
cg=True,
return_dict_in_generate=True,
output_scores=True,
timing=True,
)
print(out_cg.sequences)
if not rotary:
out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
return_dict_in_generate=True, output_scores=True)
out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length,
return_dict_in_generate=True, output_scores=True)
print(f'Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
out_hf = model_hf.generate(
input_ids=input_ids,
max_length=max_length,
return_dict_in_generate=True,
output_scores=True,
)
out_ref = model_ref.generate(
input_ids=input_ids,
max_length=max_length,
return_dict_in_generate=True,
output_scores=True,
)
print(
f"Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
)
print(
f"Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
)
print(
f"HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
)
print(
f"HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
)
print(tokenizer.batch_decode(out_ref.sequences.tolist()))
assert torch.all(out.sequences == sequences)
assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1),
rtol=rtol, atol=atol)
assert torch.allclose(
torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), rtol=rtol, atol=atol
)
if not rotary:
assert torch.all(out.sequences == out_ref.sequences)
assert torch.all(out.sequences == out_hf.sequences)
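# Tolerance check: our fp16 scores may drift from the HF fp32 reference, but by no more
# than 3x the drift of HF's own fp16 scores.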
assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()
assert (
torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)
).abs().max().item() < 3 * (
torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)
).abs().max().item()
@pytest.mark.parametrize('model_name', ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b", "facebook/opt-2.7b", "facebook/opt-6.7b"])
@pytest.mark.parametrize(
"model_name",
[
"facebook/opt-125m",
"facebook/opt-350m",
"facebook/opt-1.3b",
"facebook/opt-2.7b",
"facebook/opt-6.7b",
],
)
# @pytest.mark.parametrize('model_name', ["facebook/opt-125m"])
def test_greedy_decode_opt(model_name):
"""Check that our implementation of OPT generation matches the HF implementation:
the scores in fp16 should be around the same as the HF scores in fp16, when compared to
the HF scores in fp32.
"""
print(f'\nMODEL: {model_name}')
print(f"\nMODEL: {model_name}")
verbose = False
dtype = torch.float16
device = 'cuda'
device = "cuda"
rtol, atol = 3e-3, 3e-1
fused_ft_kernel = True
config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
# Only prenorm supports residual_in_fp32
config.residual_in_fp32 = getattr(config, 'prenorm', True)
config.residual_in_fp32 = getattr(config, "prenorm", True)
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = True
......@@ -143,8 +183,9 @@ def test_greedy_decode_opt(model_name):
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
eos_token_id = tokenizer.eos_token_id
input_ids = tokenizer("Hello, my dog is cute and he",
return_tensors="pt").input_ids.to(device=device)
input_ids = tokenizer("Hello, my dog is cute and he", return_tensors="pt").input_ids.to(
device=device
)
max_length = 25
# input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
# max_length = input_ids.shape[1] + 40
......@@ -157,7 +198,7 @@ def test_greedy_decode_opt(model_name):
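# Reference greedy decoding: rerun the full forward pass each step and (per the loop below)
# stop early once every sequence has emitted the EOS token.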
scores.append(model(cur_input_ids).logits[:, -1])
sequences.append(scores[-1].argmax(dim=-1))
for _ in range(input_ids.shape[1] + 1, max_length):
cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], 'b -> b 1')], dim=-1)
cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], "b -> b 1")], dim=-1)
scores.append(model(cur_input_ids).logits[:, -1])
sequences.append(scores[-1].argmax(dim=-1))
if eos_token_id is not None and (sequences[-1] == eos_token_id).all():
......@@ -165,31 +206,41 @@ def test_greedy_decode_opt(model_name):
sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
scores = tuple(scores)
print('Without CUDA graph')
print("Without CUDA graph")
torch.cuda.synchronize()
start = time.time()
out = model.generate(input_ids=input_ids, max_length=max_length,
eos_token_id=eos_token_id, fused_ft_kernel=fused_ft_kernel,
return_dict_in_generate=True, output_scores=True, timing=True)
out = model.generate(
input_ids=input_ids,
max_length=max_length,
eos_token_id=eos_token_id,
fused_ft_kernel=fused_ft_kernel,
return_dict_in_generate=True,
output_scores=True,
timing=True,
)
torch.cuda.synchronize()
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
if verbose:
print(out.sequences)
print(tokenizer.batch_decode(out.sequences.tolist()))
if fused_ft_kernel:
# Capture graph outside the timing loop
batch_size, seqlen_og = input_ids.shape
model._decoding_cache = update_graph_cache(
model, None, batch_size, seqlen_og, max_length
)
print('With CUDA graph')
model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
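# update_graph_cache pre-captures the decoding CUDA graphs, so the generate(..., cg=True)
# call below only replays them and the timing reflects steady-state decoding.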
print("With CUDA graph")
torch.cuda.synchronize()
start = time.time()
out_cg = model.generate(input_ids=input_ids, max_length=max_length,
fused_ft_kernel=fused_ft_kernel, cg=True,
return_dict_in_generate=True, output_scores=True, timing=True)
out_cg = model.generate(
input_ids=input_ids,
max_length=max_length,
fused_ft_kernel=fused_ft_kernel,
cg=True,
return_dict_in_generate=True,
output_scores=True,
timing=True,
)
torch.cuda.synchronize()
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
if verbose:
print(out_cg.sequences)
print(tokenizer.batch_decode(out_cg.sequences.tolist()))
......@@ -201,10 +252,11 @@ def test_greedy_decode_opt(model_name):
print("HF fp16")
torch.cuda.synchronize()
start = time.time()
out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
return_dict_in_generate=True, output_scores=True)
out_hf = model_hf.generate(
input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
)
torch.cuda.synchronize()
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
del model_hf
model_ref = OPTForCausalLM.from_pretrained(model_name).to(device=device)
......@@ -212,23 +264,35 @@ def test_greedy_decode_opt(model_name):
print("HF fp32")
torch.cuda.synchronize()
start = time.time()
out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length,
return_dict_in_generate=True, output_scores=True)
out_ref = model_ref.generate(
input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
)
torch.cuda.synchronize()
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
del model_ref
print(tokenizer.batch_decode(out_ref.sequences.tolist()))
if verbose:
print(f'Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
print(
f"Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
)
print(
f"Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
)
print(
f"HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
)
print(
f"HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
)
assert torch.all(out.sequences == sequences)
assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1),
rtol=rtol, atol=atol)
assert torch.allclose(
torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), rtol=rtol, atol=atol
)
assert torch.all(out.sequences == out_ref.sequences)
assert torch.all(out.sequences == out_hf.sequences)
assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()
assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (
torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)
).abs().max().item()
......@@ -2,34 +2,37 @@ import os
import re
import time
import torch
import pytest
import torch
from einops import rearrange
from transformers import GPT2Config
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.utils.generation import update_graph_cache
from transformers import GPT2Config
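# Helper: decode with the fused kernel and teacher forcing (teacher_outputs pins the tokens
# fed back at each step), returning the per-step scores stacked into one tensor.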
def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
out = model.generate(input_ids=input_ids, max_length=max_length, fused_ft_kernel=True,
teacher_outputs=teacher_outputs, return_dict_in_generate=True,
output_scores=True, timing=True, **kwargs)
out = model.generate(
input_ids=input_ids,
max_length=max_length,
fused_ft_kernel=True,
teacher_outputs=teacher_outputs,
return_dict_in_generate=True,
output_scores=True,
timing=True,
**kwargs,
)
return torch.stack(out.scores, dim=1)
@pytest.mark.parametrize('seqlen,maxlen', [(10, 20), (30, 150), (3000, 3400), (14000, 15000)])
@pytest.mark.parametrize("seqlen,maxlen", [(10, 20), (30, 150), (3000, 3400), (14000, 15000)])
# @pytest.mark.parametrize('seqlen,maxlen', [(10, 20)])
@pytest.mark.parametrize('rotary', [None, "interleaved", "block"])
@pytest.mark.parametrize("rotary", [None, "interleaved", "block"])
# @pytest.mark.parametrize('rotary', [None])
@pytest.mark.parametrize('model_name', ["gpt2"])
@pytest.mark.parametrize("model_name", ["gpt2"])
def test_greedy_decode_gpt2_cg(model_name, rotary, seqlen, maxlen):
"""Check that decoding with CUDA graph is the same as decoding without CUDA graph.
"""
"""Check that decoding with CUDA graph is the same as decoding without CUDA graph."""
dtype = torch.float16
device = 'cuda'
device = "cuda"
rtol, atol = 3e-3, 3e-1
config = GPT2Config.from_pretrained(model_name)
config.n_positions = 16 * 1024
......@@ -49,10 +52,12 @@ def test_greedy_decode_gpt2_cg(model_name, rotary, seqlen, maxlen):
torch.manual_seed(0)
batch_size = 1
input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long,
device=device)
teacher_outputs = torch.randint(0, config.vocab_size, (batch_size, maxlen), dtype=torch.long,
device=device)
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
)
teacher_outputs = torch.randint(
0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device
)
logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
......@@ -61,20 +66,24 @@ def test_greedy_decode_gpt2_cg(model_name, rotary, seqlen, maxlen):
# Try increasing batch size and seqlen, then decrease them to see if it's still correct
batch_size = 3
maxlen += 30
input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long,
device=device)
teacher_outputs = torch.randint(0, config.vocab_size, (batch_size, maxlen), dtype=torch.long,
device=device)
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
)
teacher_outputs = torch.randint(
0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device
)
logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
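# Decoding through the captured CUDA graph must match eager decoding exactly.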
assert torch.equal(logits, logits_cg)
batch_size = 2
maxlen -= 35
input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long,
device=device)
teacher_outputs = torch.randint(0, config.vocab_size, (batch_size, maxlen), dtype=torch.long,
device=device)
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
)
teacher_outputs = torch.randint(
0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device
)
logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
assert torch.equal(logits, logits_cg)
......@@ -3,27 +3,23 @@
import os
import re
import torch
import pytest
import torch
from einops import rearrange
from flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_gpt2
from flash_attn.utils.distributed import all_gather_raw
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import GPT2Config, GPT2Tokenizer
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gpt import remap_state_dict_hf_gpt2
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.distributed import all_gather_raw
# @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
@pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize("world_size", [2])
# @pytest.mark.parametrize('fused_ft_kernel', [False, True])
@pytest.mark.parametrize('fused_ft_kernel', [True])
@pytest.mark.parametrize("fused_ft_kernel", [True])
# @pytest.mark.parametrize('rotary', [False, True])
@pytest.mark.parametrize('rotary', [False])
@pytest.mark.parametrize('model_name', ["gpt2"])
@pytest.mark.parametrize("rotary", [False])
@pytest.mark.parametrize("model_name", ["gpt2"])
def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
"""Check that our implementation of GPT2 generation matches the HF implementation:
the scores in fp16 should be around the same as the HF scores in fp16, when compared to
......@@ -45,23 +41,31 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend='nccl', init_method='env://')
device = f'cuda:{torch.distributed.get_rank()}'
torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size()
# Need this, otherwise when we capture the CUDA graph the process for GPU 1 would run on both
# GPU 0 and GPU 1 and things would hang
torch.cuda.set_device(device)
from apex.transformer import parallel_state
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
process_group = parallel_state.get_tensor_model_parallel_group()
# if not rotary, we load the weight from HF but ignore the position embeddings.
# The model would be nonsense but it doesn't matter for the test.
model = GPTLMHeadModel.from_pretrained(model_name, config, strict=not rotary, device=device,
dtype=dtype, process_group=process_group,
world_size=world_size, rank=rank)
model = GPTLMHeadModel.from_pretrained(
model_name,
config,
strict=not rotary,
device=device,
dtype=dtype,
process_group=process_group,
world_size=world_size,
rank=rank,
)
model.eval()
if not rotary:
......@@ -72,8 +76,9 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
torch.manual_seed(0)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
input_ids = tokenizer("Hello, my dog is cute and ",
return_tensors="pt").input_ids.to(device=device)
input_ids = tokenizer("Hello, my dog is cute and ", return_tensors="pt").input_ids.to(
device=device
)
max_length = 30
# input_ids = torch.randint(0, 100, (1, 10), dtype=torch.long, device='cuda')
# max_length = input_ids.shape[1] + 40
......@@ -84,50 +89,87 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
cur_input_ids = input_ids
with torch.inference_mode():
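# Each rank produces a shard of the vocab logits; all-gather them, stitch the shards back
# together along the vocab dimension, then trim off the vocab padding.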
logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group)
logits = rearrange(logits, '(n b) d -> b (n d)',
b=input_ids.shape[0])[..., :config.vocab_size]
logits = rearrange(logits, "(n b) d -> b (n d)", b=input_ids.shape[0])[
..., : config.vocab_size
]
scores.append(logits)
sequences.append(scores[-1].argmax(dim=-1))
for _ in range(input_ids.shape[1] + 1, max_length):
cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], 'b -> b 1')], dim=-1)
cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], "b -> b 1")], dim=-1)
logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group)
logits = rearrange(logits, '(n b) d -> b (n d)',
b=input_ids.shape[0])[..., :config.vocab_size]
logits = rearrange(logits, "(n b) d -> b (n d)", b=input_ids.shape[0])[
..., : config.vocab_size
]
scores.append(logits)
sequences.append(scores[-1].argmax(dim=-1))
sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
scores = tuple(scores)
print(sequences)
out = model.generate(input_ids=input_ids, max_length=max_length, tensor_parallel=world_size,
vocab_size=config.vocab_size, fused_ft_kernel=fused_ft_kernel,
return_dict_in_generate=True, output_scores=True, timing=True)
out = model.generate(
input_ids=input_ids,
max_length=max_length,
tensor_parallel=world_size,
vocab_size=config.vocab_size,
fused_ft_kernel=fused_ft_kernel,
return_dict_in_generate=True,
output_scores=True,
timing=True,
)
print(out.sequences)
if fused_ft_kernel:
out_cg = model.generate(
input_ids=input_ids, max_length=max_length, tensor_parallel=world_size,
vocab_size=config.vocab_size, fused_ft_kernel=fused_ft_kernel, cg=True,
return_dict_in_generate=True, output_scores=True, timing=True)
input_ids=input_ids,
max_length=max_length,
tensor_parallel=world_size,
vocab_size=config.vocab_size,
fused_ft_kernel=fused_ft_kernel,
cg=True,
return_dict_in_generate=True,
output_scores=True,
timing=True,
)
print(out_cg.sequences)
if not rotary:
out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
return_dict_in_generate=True, output_scores=True)
out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length,
return_dict_in_generate=True, output_scores=True)
print(f'Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
out_hf = model_hf.generate(
input_ids=input_ids,
max_length=max_length,
return_dict_in_generate=True,
output_scores=True,
)
out_ref = model_ref.generate(
input_ids=input_ids,
max_length=max_length,
return_dict_in_generate=True,
output_scores=True,
)
print(
f"Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
)
print(
f"Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
)
print(
f"HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
)
print(
f"HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
)
assert torch.all(out.sequences == sequences)
assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1),
rtol=rtol, atol=atol)
assert torch.allclose(
torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), rtol=rtol, atol=atol
)
if not rotary:
assert torch.all(out.sequences == out_ref.sequences)
assert torch.all(out.sequences == out_hf.sequences)
assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()
assert (
torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)
).abs().max().item() < 3 * (
torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)
).abs().max().item()
parallel_state.destroy_model_parallel()
......@@ -2,37 +2,37 @@
import time
import torch
import pytest
from transformers import GPTNeoXConfig, AutoTokenizer
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
import torch
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox, gpt_neox_config_to_gpt2_config
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.models.gpt_neox import gpt_neox_config_to_gpt2_config, remap_state_dict_hf_gpt_neox
from flash_attn.utils.generation import update_graph_cache
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import AutoTokenizer, GPTNeoXConfig
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
@pytest.mark.parametrize('model_name', ["EleutherAI/gpt-neox-20b"])
@pytest.mark.parametrize("model_name", ["EleutherAI/gpt-neox-20b"])
def test_gpt_neox_state_dict(model_name):
config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(model_name))
pretrained_state_dict = remap_state_dict_hf_gpt_neox(state_dict_from_pretrained(model_name), config)
model = GPTLMHeadModel(config, device='meta') # Without device='meta' init is very slow
pretrained_state_dict = remap_state_dict_hf_gpt_neox(
state_dict_from_pretrained(model_name), config
)
model = GPTLMHeadModel(config, device="meta") # Without device='meta' init is very slow
state_dict = model.state_dict()
assert state_dict.keys() == pretrained_state_dict.keys()
for k in state_dict.keys():
assert state_dict[k].shape == pretrained_state_dict[k].shape
@pytest.mark.parametrize('model_name', ["EleutherAI/gpt-neox-20b"])
@pytest.mark.parametrize("model_name", ["EleutherAI/gpt-neox-20b"])
def test_gpt_neox_optimized(model_name):
"""Check that our implementation of GPT-NeoX (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
dtype = torch.float16
device = 'cuda'
device = "cuda"
config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(model_name))
config.use_flash_attn = True
config.fused_bias_fc = True
......@@ -47,8 +47,9 @@ def test_gpt_neox_optimized(model_name):
batch_size = 2
max_seqlen = 256
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
device=device)
input_ids = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
)
with torch.no_grad():
out = model.transformer(input_ids)
logits = model(input_ids).logits
......@@ -56,31 +57,36 @@ def test_gpt_neox_optimized(model_name):
# Need at least 2 GPUs, otherwise we'll OOM
# Without device_map, the model is loaded on the CPU, which is very slow
model_ref = GPTNeoXForCausalLM.from_pretrained(model_name, device_map='auto')
model_ref = GPTNeoXForCausalLM.from_pretrained(model_name, device_map="auto")
model_ref.eval()
with torch.no_grad():
out_ref = model_ref.gpt_neox(input_ids).last_hidden_state.to(device=device)
logits_ref = model_ref(input_ids).logits.to(device=device)
del model_ref
model_hf = GPTNeoXForCausalLM.from_pretrained(model_name, torch_dtype=dtype,
device_map={"": device})
model_hf = GPTNeoXForCausalLM.from_pretrained(
model_name, torch_dtype=dtype, device_map={"": device}
)
model_hf.eval()
with torch.no_grad():
out_hf = model_hf.gpt_neox(input_ids).last_hidden_state
logits_hf = model_hf(input_ids).logits
del model_hf
print(f'Output max diff: {(out - out_ref).abs().max().item()}')
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()
assert (out - out_ref).abs().mean().item() < 2 * (out_hf - out_ref).abs().mean().item()
print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
assert (logits - logits_ref).abs().max().item() < 2 * (logits_hf - logits_ref).abs().max().item()
assert (logits - logits_ref).abs().mean().item() < 2 * (logits_hf - logits_ref).abs().mean().item()
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
assert (logits - logits_ref).abs().max().item() < 2 * (
logits_hf - logits_ref
).abs().max().item()
assert (logits - logits_ref).abs().mean().item() < 2 * (
logits_hf - logits_ref
).abs().mean().item()
......@@ -3,33 +3,29 @@
import math
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytest
from einops import rearrange
from transformers import GPT2Config
from apex.transformer import parallel_state
from flash_attn.models.gpt import GPTLMHeadModel, shard_state_dict_tp
from einops import rearrange
from flash_attn.losses.cross_entropy import CrossEntropyLoss
from flash_attn.models.gpt import GPTLMHeadModel, shard_state_dict_tp
from flash_attn.utils.distributed import allreduce_sequence_parallel_grad
from transformers import GPT2Config
is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
@pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('sequence_parallel', [True, False])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize('has_pos_emb', [True, False])
@pytest.mark.parametrize("has_pos_emb", [True, False])
# @pytest.mark.parametrize('has_pos_emb', [True])
@pytest.mark.parametrize('dim', [1024])
@pytest.mark.parametrize("dim", [1024])
def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
head_dim = 64
assert dim % head_dim == 0
......@@ -40,8 +36,8 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
num_layers = 2
rtol, atol = (3e-3, 1e-1) if dtype == torch.bfloat16 else (3e-3, 1e-2)
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend='nccl', init_method='env://')
device = f'cuda:{torch.distributed.get_rank()}'
torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
......@@ -57,15 +53,25 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
# as rank 0 will have an extra bias that changes the RNG.
g = torch.randn(batch_size * seqlen, device=device)
config = GPT2Config(n_embd=dim, n_head=num_heads, n_layer=num_layers,
n_positions=seqlen if has_pos_emb else 0,
vocab_size=50257, resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0,
scale_attn_by_inverse_layer_idx=True, use_flash_attn=True,
fused_mlp=True, fused_bias_fc=True, fused_dropout_add_ln=True,
residual_in_fp32=True,
rotary_emb_fraction=0.0 if has_pos_emb else 0.5,
pad_vocab_size_multiple=8 * world_size,
sequence_parallel=sequence_parallel)
config = GPT2Config(
n_embd=dim,
n_head=num_heads,
n_layer=num_layers,
n_positions=seqlen if has_pos_emb else 0,
vocab_size=50257,
resid_pdrop=0.0,
embd_pdrop=0.0,
attn_pdrop=0.0,
scale_attn_by_inverse_layer_idx=True,
use_flash_attn=True,
fused_mlp=True,
fused_bias_fc=True,
fused_dropout_add_ln=True,
residual_in_fp32=True,
rotary_emb_fraction=0.0 if has_pos_emb else 0.5,
pad_vocab_size_multiple=8 * world_size,
sequence_parallel=sequence_parallel,
)
config.vocab_size = math.ceil(config.vocab_size / (8 * world_size)) * (8 * world_size)
model_pt = GPTLMHeadModel(config, device=device)
......@@ -73,6 +79,7 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
if isinstance(module, nn.LayerNorm):
nn.init.normal_(module.weight)
nn.init.normal_(module.bias)
model_pt.apply(init_layer_norm)
model = GPTLMHeadModel(config, process_group=process_group, device=device)
......@@ -82,15 +89,17 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
torch.distributed.all_gather_into_tensor(
sharded_nparams_all, torch.tensor([sharded_nparams], device=device), group=process_group
)
shared_nparams = sum(p.numel() for p in model.parameters()
if getattr(p, '_shared_params', False))
shared_nparams = sum(
p.numel() for p in model.parameters() if getattr(p, "_shared_params", False)
)
shared_nparams_all = torch.empty(world_size, dtype=torch.long, device=device)
torch.distributed.all_gather_into_tensor(
shared_nparams_all, torch.tensor([shared_nparams], device=device), group=process_group
)
assert torch.all(shared_nparams_all == shared_nparams)
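# The unsharded parameter count should equal the sum of the per-rank shards, counting the
# parameters that are replicated on every rank only once.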
assert total_nparams == ((sharded_nparams_all - shared_nparams_all).sum().item()
+ shared_nparams)
assert total_nparams == (
(sharded_nparams_all - shared_nparams_all).sum().item() + shared_nparams
)
# vocab_size has been rounded up here
partition_vocab_size = config.vocab_size // world_size
......@@ -100,18 +109,20 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
model.load_state_dict(shard_state_dict_tp(model_pt.state_dict(), config, world_size, rank))
model.tie_weights()
with torch.autocast(device_type='cuda', dtype=dtype):
with torch.autocast(device_type="cuda", dtype=dtype):
out = model(input_ids[:, :-1]).logits
if not sequence_parallel:
out = rearrange(out, 'b s d -> (b s) d')
out_pt = rearrange(model_pt(input_ids[:, :-1]).logits, 'b s d -> (b s) d')
out = rearrange(out, "b s d -> (b s) d")
out_pt = rearrange(model_pt(input_ids[:, :-1]).logits, "b s d -> (b s) d")
partition_batch_dim = batch_size * seqlen // world_size
assert torch.allclose(
out, out_pt[:, rank * partition_vocab_size:(rank + 1) * partition_vocab_size],
rtol=rtol, atol=atol
out,
out_pt[:, rank * partition_vocab_size : (rank + 1) * partition_vocab_size],
rtol=rtol,
atol=atol,
)
loss_fn = CrossEntropyLoss(inplace_backward=True, reduction='none', process_group=process_group)
loss_fn_pt = CrossEntropyLoss(inplace_backward=True, reduction='none')
loss_fn = CrossEntropyLoss(inplace_backward=True, reduction="none", process_group=process_group)
loss_fn_pt = CrossEntropyLoss(inplace_backward=True, reduction="none")
loss = loss_fn(out, input_ids[:, 1:].flatten())
loss_pt = loss_fn_pt(out_pt, input_ids[:, 1:].flatten())
assert torch.allclose(loss, loss_pt, rtol=rtol, atol=atol)
......@@ -121,73 +132,105 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
allreduce_sequence_parallel_grad(model, process_group)
parallel_state.destroy_model_parallel()
grad_dict = shard_state_dict_tp({k: v.grad for k, v in model_pt.named_parameters()},
config, world_size, rank)
grad_dict = shard_state_dict_tp(
{k: v.grad for k, v in model_pt.named_parameters()}, config, world_size, rank
)
assert torch.allclose(
model.transformer.embeddings.word_embeddings.weight.grad,
grad_dict['transformer.embeddings.word_embeddings.weight'],
rtol=rtol, atol=atol * 5
grad_dict["transformer.embeddings.word_embeddings.weight"],
rtol=rtol,
atol=atol * 5,
)
if has_pos_emb:
assert torch.allclose(
model.transformer.embeddings.position_embeddings.weight.grad,
grad_dict['transformer.embeddings.position_embeddings.weight'],
rtol=rtol, atol=atol
grad_dict["transformer.embeddings.position_embeddings.weight"],
rtol=rtol,
atol=atol,
)
assert torch.allclose(model.transformer.ln_f.weight.grad, grad_dict['transformer.ln_f.weight'],
rtol=rtol, atol=atol)
assert torch.allclose(model.transformer.ln_f.bias.grad, grad_dict['transformer.ln_f.bias'],
rtol=rtol, atol=atol)
assert torch.allclose(
model.transformer.ln_f.weight.grad,
grad_dict["transformer.ln_f.weight"],
rtol=rtol,
atol=atol,
)
assert torch.allclose(
model.transformer.ln_f.bias.grad, grad_dict["transformer.ln_f.bias"], rtol=rtol, atol=atol
)
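# Per-layer gradient checks below use looser tolerances (atol * 5 / atol * 10); presumably the
# sharded matmuls and reductions accumulate more low-precision error than the embeddings/norms.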
for i in range(num_layers):
assert torch.allclose(
model.transformer.layers[i].mixer.Wqkv.weight.grad,
grad_dict[f'transformer.layers.{i}.mixer.Wqkv.weight'],
rtol=rtol, atol=atol * 10
grad_dict[f"transformer.layers.{i}.mixer.Wqkv.weight"],
rtol=rtol,
atol=atol * 10,
)
assert torch.allclose(
model.transformer.layers[i].mixer.Wqkv.bias.grad,
grad_dict[f'transformer.layers.{i}.mixer.Wqkv.bias'],
rtol=rtol, atol=atol * 10
grad_dict[f"transformer.layers.{i}.mixer.Wqkv.bias"],
rtol=rtol,
atol=atol * 10,
)
assert torch.allclose(
model.transformer.layers[i].mixer.out_proj.weight.grad,
grad_dict[f'transformer.layers.{i}.mixer.out_proj.weight'],
rtol=rtol, atol=atol * 10
grad_dict[f"transformer.layers.{i}.mixer.out_proj.weight"],
rtol=rtol,
atol=atol * 10,
)
if rank == 0:
assert torch.allclose(model.transformer.layers[i].mixer.out_proj.bias.grad,
grad_dict[f'transformer.layers.{i}.mixer.out_proj.bias'],
rtol=rtol, atol=atol * 5)
assert torch.allclose(
model.transformer.layers[i].mixer.out_proj.bias.grad,
grad_dict[f"transformer.layers.{i}.mixer.out_proj.bias"],
rtol=rtol,
atol=atol * 5,
)
assert torch.allclose(
model.transformer.layers[i].mlp.fc1.weight.grad,
grad_dict[f'transformer.layers.{i}.mlp.fc1.weight'],
rtol=rtol, atol=atol * 10
grad_dict[f"transformer.layers.{i}.mlp.fc1.weight"],
rtol=rtol,
atol=atol * 10,
)
assert torch.allclose(
model.transformer.layers[i].mlp.fc1.bias.grad,
grad_dict[f'transformer.layers.{i}.mlp.fc1.bias'],
rtol=rtol, atol=atol * 10
grad_dict[f"transformer.layers.{i}.mlp.fc1.bias"],
rtol=rtol,
atol=atol * 10,
)
assert torch.allclose(
model.transformer.layers[i].mlp.fc2.weight.grad,
grad_dict[f'transformer.layers.{i}.mlp.fc2.weight'],
rtol=rtol, atol=atol * 10
grad_dict[f"transformer.layers.{i}.mlp.fc2.weight"],
rtol=rtol,
atol=atol * 10,
)
if rank == 0:
assert torch.allclose(model.transformer.layers[i].mlp.fc2.bias.grad,
grad_dict[f'transformer.layers.{i}.mlp.fc2.bias'],
rtol=rtol, atol=atol * 5)
assert torch.allclose(model.transformer.layers[i].norm1.weight.grad,
grad_dict[f'transformer.layers.{i}.norm1.weight'],
rtol=rtol, atol=atol)
assert torch.allclose(model.transformer.layers[i].norm1.bias.grad,
grad_dict[f'transformer.layers.{i}.norm1.bias'],
rtol=rtol, atol=atol)
assert torch.allclose(model.transformer.layers[i].norm2.weight.grad,
grad_dict[f'transformer.layers.{i}.norm2.weight'],
rtol=rtol, atol=atol)
assert torch.allclose(model.transformer.layers[i].norm2.bias.grad,
grad_dict[f'transformer.layers.{i}.norm2.bias'],
rtol=rtol, atol=atol)
assert torch.allclose(
model.transformer.layers[i].mlp.fc2.bias.grad,
grad_dict[f"transformer.layers.{i}.mlp.fc2.bias"],
rtol=rtol,
atol=atol * 5,
)
assert torch.allclose(
model.transformer.layers[i].norm1.weight.grad,
grad_dict[f"transformer.layers.{i}.norm1.weight"],
rtol=rtol,
atol=atol,
)
assert torch.allclose(
model.transformer.layers[i].norm1.bias.grad,
grad_dict[f"transformer.layers.{i}.norm1.bias"],
rtol=rtol,
atol=atol,
)
assert torch.allclose(
model.transformer.layers[i].norm2.weight.grad,
grad_dict[f"transformer.layers.{i}.norm2.weight"],
rtol=rtol,
atol=atol,
)
assert torch.allclose(
model.transformer.layers[i].norm2.bias.grad,
grad_dict[f"transformer.layers.{i}.norm2.bias"],
rtol=rtol,
atol=atol,
)
......@@ -2,37 +2,35 @@
import time
import torch
import pytest
from transformers import GPTJConfig, AutoTokenizer
from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
import torch
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gptj import remap_state_dict_hf_gptj, gptj_config_to_gpt2_config
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.models.gptj import gptj_config_to_gpt2_config, remap_state_dict_hf_gptj
from flash_attn.utils.generation import update_graph_cache
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import AutoTokenizer, GPTJConfig
from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
@pytest.mark.parametrize('model_name', ["EleutherAI/gpt-j-6B"])
@pytest.mark.parametrize("model_name", ["EleutherAI/gpt-j-6B"])
def test_gptj_state_dict(model_name):
config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
pretrained_state_dict = remap_state_dict_hf_gptj(state_dict_from_pretrained(model_name), config)
model = GPTLMHeadModel(config, device='meta') # Without device='meta' init is very slow
model = GPTLMHeadModel(config, device="meta") # Without device='meta' init is very slow
state_dict = model.state_dict()
assert state_dict.keys() == pretrained_state_dict.keys()
for k in state_dict.keys():
assert state_dict[k].shape == pretrained_state_dict[k].shape
@pytest.mark.parametrize('model_name', ["EleutherAI/gpt-j-6B"])
@pytest.mark.parametrize("model_name", ["EleutherAI/gpt-j-6B"])
def test_gptj_optimized(model_name):
"""Check that our implementation of GPT-J (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
dtype = torch.float16
device = 'cuda'
device = "cuda"
config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
config.use_flash_attn = True # FlashAttention-2 supports headdim 256
config.fused_bias_fc = True
......@@ -46,8 +44,9 @@ def test_gptj_optimized(model_name):
torch.manual_seed(0)
batch_size = 2
max_seqlen = 256
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
device=device)
input_ids = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
)
with torch.no_grad():
out = model.transformer(input_ids)
logits = model(input_ids).logits
......@@ -61,34 +60,37 @@ def test_gptj_optimized(model_name):
logits_ref = model_ref(input_ids).logits
del model_ref
model_hf = GPTJForCausalLM.from_pretrained(model_name, torch_dtype=dtype,
device_map={"": device})
model_hf = GPTJForCausalLM.from_pretrained(
model_name, torch_dtype=dtype, device_map={"": device}
)
model_hf.eval()
out_hf = model_hf.transformer(input_ids).last_hidden_state
logits_hf = model_hf(input_ids).logits
del model_hf
print(f'Output max diff: {(out - out_ref).abs().max().item()}')
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item()
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
assert (logits - logits_ref).abs().max().item() < 3 * (
logits_hf - logits_ref
).abs().max().item()
@pytest.mark.parametrize('model_name', ["EleutherAI/gpt-j-6B"])
@pytest.mark.parametrize("model_name", ["EleutherAI/gpt-j-6B"])
def test_gptj_generation(model_name):
"""Check that our implementation of GPT-J (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
dtype = torch.float16
device = 'cuda'
device = "cuda"
config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
config.use_flash_attn = True # FlashAttention-2 supports headdim 256
config.fused_bias_fc = True
......@@ -104,56 +106,71 @@ def test_gptj_generation(model_name):
batch_size = 1
seqlen = 100
max_length = 150
input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long,
device=device)
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
)
model_hf = GPTJForCausalLM.from_pretrained(model_name, torch_dtype=dtype,
device_map={"": device})
model_hf = GPTJForCausalLM.from_pretrained(
model_name, torch_dtype=dtype, device_map={"": device}
)
model_hf.eval()
print("HF fp16")
torch.cuda.synchronize()
start = time.time()
out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
return_dict_in_generate=True, output_scores=True)
out_hf = model_hf.generate(
input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
)
torch.cuda.synchronize()
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
del model_hf
model_ref = GPTJForCausalLM.from_pretrained(model_name, device_map={"": device})
model_ref.eval()
with torch.no_grad():
logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1]
logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]
del model_ref
model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
model.eval()
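# teacher_outputs=out_hf.sequences forces our decoder to follow the HF-generated tokens, so
# the per-step scores line up position-by-position with the HF scores.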
print('Without CUDA graph')
print("Without CUDA graph")
torch.cuda.synchronize()
start = time.time()
out = model.generate(input_ids=input_ids, max_length=max_length,
eos_token_id=eos_token_id, fused_ft_kernel=True,
# eos_token_id=eos_token_id, fused_ft_kernel=False,
return_dict_in_generate=True, output_scores=True, timing=True,
teacher_outputs=out_hf.sequences)
out = model.generate(
input_ids=input_ids,
max_length=max_length,
eos_token_id=eos_token_id,
fused_ft_kernel=True,
# eos_token_id=eos_token_id, fused_ft_kernel=False,
return_dict_in_generate=True,
output_scores=True,
timing=True,
teacher_outputs=out_hf.sequences,
)
torch.cuda.synchronize()
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
# Capture graph outside the timing loop
batch_size, seqlen_og = input_ids.shape
model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
print('With CUDA graph')
print("With CUDA graph")
torch.cuda.synchronize()
start = time.time()
out_cg = model.generate(input_ids=input_ids, max_length=max_length,
fused_ft_kernel=True, cg=True,
return_dict_in_generate=True, output_scores=True, timing=True,
teacher_outputs=out_hf.sequences)
out_cg = model.generate(
input_ids=input_ids,
max_length=max_length,
fused_ft_kernel=True,
cg=True,
return_dict_in_generate=True,
output_scores=True,
timing=True,
teacher_outputs=out_hf.sequences,
)
torch.cuda.synchronize()
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
with torch.no_grad():
logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1):-1]
logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1]
logits_hf = torch.stack(out_hf.scores, dim=1)
logits = torch.stack(out.scores, dim=1)
logits_cg = torch.stack(out_cg.scores, dim=1)
......@@ -163,8 +180,8 @@ def test_gptj_generation(model_name):
hf_error = (logits_hf - logits_ref).abs().max().item()
assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error
print(f'HF fp16 logits max diff: {hf_error}')
print(f'Logits max diff: {(logits - logits_ref).abs().max().item() }')
print(f"HF fp16 logits max diff: {hf_error}")
print(f"Logits max diff: {(logits - logits_ref).abs().max().item() }")
assert (logits - logits_ref).abs().max().item() < 2 * hf_error
print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }')
print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }")
assert torch.equal(logits_cg, logits)
......@@ -11,26 +11,25 @@ from pathlib import Path
current_dir = Path(__file__).parent.absolute()
import torch
import pytest
import shutil
import pytest
import torch
from einops import rearrange
from transformers import LlamaTokenizer, LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp, shard_state_dict_tp
from flash_attn.models.llama import (
remap_state_dict_meta_llama,
config_from_checkpoint,
inv_remap_state_dict_hf_llama,
llama_config_to_gpt2_config,
remap_state_dict_hf_llama,
inv_remap_state_dict_hf_llama,
remap_state_dict_meta_llama,
state_dicts_from_checkpoint,
)
from flash_attn.models.llama import config_from_checkpoint, state_dicts_from_checkpoint
from flash_attn.utils.distributed import all_gather_raw
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.generation import update_graph_cache
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import LlamaConfig, LlamaTokenizer
from transformers.models.llama.modeling_llama import LlamaForCausalLM
def _pretrained_state_dict_from_checkpoint(checkpoint_path, model_name, config, checkpoint_format):
......
import re
import torch
import pytest
from transformers import OPTConfig
from transformers.models.opt.modeling_opt import OPTForCausalLM
import torch
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.opt import remap_state_dict_hf_opt, opt_config_to_gpt2_config
from flash_attn.models.opt import opt_config_to_gpt2_config, remap_state_dict_hf_opt
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import OPTConfig
from transformers.models.opt.modeling_opt import OPTForCausalLM
@pytest.mark.parametrize('model_name', ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b"])
@pytest.mark.parametrize(
"model_name", ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b"]
)
# @pytest.mark.parametrize('model_name', ["facebook/opt-350m"])
def test_opt_state_dict(model_name):
config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
......@@ -23,7 +23,9 @@ def test_opt_state_dict(model_name):
assert state_dict[k].shape == pretrained_state_dict[k].shape
@pytest.mark.parametrize('model_name', ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b"])
@pytest.mark.parametrize(
"model_name", ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b"]
)
# @pytest.mark.parametrize('model_name', ["facebook/opt-350m"])
def test_opt_optimized(model_name):
"""Check that our implementation of OPT (without all optimizations enabled) matches the
......@@ -31,14 +33,14 @@ def test_opt_optimized(model_name):
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
dtype = torch.float16
device = 'cuda'
device = "cuda"
config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = True
config.fused_dropout_add_ln = True
# Only prenorm supports residual_in_fp32
config.residual_in_fp32 = getattr(config, 'prenorm', True)
config.residual_in_fp32 = getattr(config, "prenorm", True)
config.pad_vocab_size_multiple = 8
model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
......@@ -53,26 +55,29 @@ def test_opt_optimized(model_name):
torch.manual_seed(0)
batch_size = 2
max_seqlen = 256
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda')
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
device='cuda')
if model_name != 'facebook/opt-350m': # The OPT-350m projects the embeddings to dimension 512
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
input_ids = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
)
if model_name != "facebook/opt-350m": # The OPT-350m projects the embeddings to dimension 512
out = model.transformer(input_ids)
out_hf = model_hf.model(input_ids).last_hidden_state
out_ref = model_ref.model(input_ids).last_hidden_state
print(f'Output max diff: {(out - out_ref).abs().max().item()}')
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
logits = model(input_ids).logits
logits_hf = model_hf(input_ids).logits
logits_ref = model_ref(input_ids).logits
print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item()
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
assert (logits - logits_ref).abs().max().item() < 3 * (
logits_hf - logits_ref
).abs().max().item()
import re
import torch
import pytest
from timm.models.vision_transformer import vit_base_patch16_224
import torch
from flash_attn.models.vit import vit_base_patch16_224 as flash_vit_base_patch16_224
from timm.models.vision_transformer import vit_base_patch16_224
@pytest.mark.parametrize('fused_mlp', [False, True])
@pytest.mark.parametrize("fused_mlp", [False, True])
# @pytest.mark.parametrize('fused_mlp', [False])
@pytest.mark.parametrize('optimized', [False, True])
@pytest.mark.parametrize("optimized", [False, True])
# @pytest.mark.parametrize('optimized', [True])
def test_vit(optimized, fused_mlp):
"""Check that our implementation of ViT matches the timm's implementation:
......@@ -18,12 +16,12 @@ def test_vit(optimized, fused_mlp):
timm's forward pass in fp16, when compared to timm's forward pass in fp32.
"""
dtype = torch.float16
device = 'cuda'
device = "cuda"
kwargs = {}
if optimized:
kwargs = dict(use_flash_attn=True, fused_bias_fc=True, fused_dropout_add_ln=True)
kwargs['fused_mlp'] = fused_mlp
kwargs["fused_mlp"] = fused_mlp
model = flash_vit_base_patch16_224(**kwargs).to(device=device, dtype=dtype)
model_ref = vit_base_patch16_224(pretrained=True).to(device=device)
......@@ -42,9 +40,9 @@ def test_vit(optimized, fused_mlp):
out_timm = model_timm(x)
out_ref = model_ref(x.float())
print(f'Output max diff: {(out - out_ref).abs().max().item()}')
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
print(f'timm fp16 max diff: {(out_timm - out_ref).abs().max().item()}')
print(f'timm fp16 mean diff: {(out_timm - out_ref).abs().mean().item()}')
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"timm fp16 max diff: {(out_timm - out_ref).abs().max().item()}")
print(f"timm fp16 mean diff: {(out_timm - out_ref).abs().mean().item()}")
rtol = 2 if not fused_mlp else 8
assert (out - out_ref).abs().max().item() < rtol * (out_timm - out_ref).abs().max().item()
......@@ -4,31 +4,27 @@
import math
from functools import partial
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytest
from apex.transformer import parallel_state, tensor_parallel
from einops import rearrange
from apex.transformer import parallel_state
from apex.transformer import tensor_parallel
from flash_attn.modules.block import Block
from flash_attn.modules.mha import MHA, ParallelMHA
from flash_attn.modules.mlp import FusedMLP, ParallelFusedMLP
from flash_attn.modules.block import Block
from flash_attn.utils.distributed import allreduce_sequence_parallel_grad
is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
@pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('sequence_parallel', [True, False])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [True])
@pytest.mark.parametrize('dim', [1024])
@pytest.mark.parametrize("dim", [1024])
def test_block_parallel(dim, sequence_parallel, world_size, dtype):
head_dim = 64
assert dim % head_dim == 0
......@@ -36,8 +32,8 @@ def test_block_parallel(dim, sequence_parallel, world_size, dtype):
assert num_heads % world_size == 0
rtol, atol = (3e-3, 5e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend='nccl', init_method='env://')
device = f'cuda:{torch.distributed.get_rank()}'
torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
......@@ -46,22 +42,37 @@ def test_block_parallel(dim, sequence_parallel, world_size, dtype):
batch_size = 2
seqlen = 1024
assert (batch_size * seqlen) % world_size == 0
x_pt = torch.randn(batch_size * seqlen, dim, device=device, dtype=dtype,
requires_grad=True)
x_pt = torch.randn(batch_size * seqlen, dim, device=device, dtype=dtype, requires_grad=True)
residual_pt = torch.randn(batch_size * seqlen, dim, device=device, requires_grad=True)
# We need to generate g here so that all processes get the same gradient,
# as rank 0 will have an extra bias that changes the RNG.
# If we don't divide by batch_size, the gradient gets a bit too large.
g = torch.randn_like(x_pt) / 32
if sequence_parallel:
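# Sequence parallelism: each rank owns a contiguous slice of the flattened (batch * seqlen)
# tokens, so scatter the inputs (and residual) across ranks before running the block.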
x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
residual = tensor_parallel.scatter_to_sequence_parallel_region(residual_pt).detach().clone().requires_grad_()
x = (
tensor_parallel.scatter_to_sequence_parallel_region(x_pt)
.detach()
.clone()
.requires_grad_()
)
residual = (
tensor_parallel.scatter_to_sequence_parallel_region(residual_pt)
.detach()
.clone()
.requires_grad_()
)
else:
x = x_pt.detach().clone().requires_grad_()
residual = residual_pt.detach().clone().requires_grad_()
mixer_cls_pt = partial(MHA, num_heads=num_heads, rotary_emb_dim=int(head_dim // 2),
use_flash_attn=True, device=device, dtype=dtype)
mixer_cls_pt = partial(
MHA,
num_heads=num_heads,
rotary_emb_dim=int(head_dim // 2),
use_flash_attn=True,
device=device,
dtype=dtype,
)
mlp_cls_pt = partial(FusedMLP, hidden_features=4 * dim, device=device, dtype=dtype)
norm_cls = partial(nn.LayerNorm, device=device, dtype=dtype)
model_pt = Block(dim, mixer_cls_pt, mlp_cls_pt, norm_cls, fused_dropout_add_ln=True)
......@@ -71,40 +82,68 @@ def test_block_parallel(dim, sequence_parallel, world_size, dtype):
nn.init.normal_(model_pt.norm2.weight)
nn.init.normal_(model_pt.norm2.bias)
mixer_cls = partial(ParallelMHA, num_heads=num_heads,
process_group=parallel_state.get_tensor_model_parallel_group(),
rotary_emb_dim=int(head_dim // 2), use_flash_attn=True,
sequence_parallel=sequence_parallel, device=device, dtype=dtype)
mlp_cls = partial(ParallelFusedMLP, hidden_features=4 * dim,
process_group=parallel_state.get_tensor_model_parallel_group(),
sequence_parallel=sequence_parallel, device=device, dtype=dtype)
model = Block(dim, mixer_cls, mlp_cls, norm_cls, fused_dropout_add_ln=True,
sequence_parallel=sequence_parallel, mark_shared_params=True)
mixer_cls = partial(
ParallelMHA,
num_heads=num_heads,
process_group=parallel_state.get_tensor_model_parallel_group(),
rotary_emb_dim=int(head_dim // 2),
use_flash_attn=True,
sequence_parallel=sequence_parallel,
device=device,
dtype=dtype,
)
mlp_cls = partial(
ParallelFusedMLP,
hidden_features=4 * dim,
process_group=parallel_state.get_tensor_model_parallel_group(),
sequence_parallel=sequence_parallel,
device=device,
dtype=dtype,
)
model = Block(
dim,
mixer_cls,
mlp_cls,
norm_cls,
fused_dropout_add_ln=True,
sequence_parallel=sequence_parallel,
mark_shared_params=True,
)
partition_dim = dim // world_size
partition_hidden_dim = 4 * dim // world_size
with torch.no_grad():
model.mixer.Wqkv.weight.copy_(
rearrange(rearrange(model_pt.mixer.Wqkv.weight, '(three o) i -> three o i', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim],
'three o i -> (three o) i')
rearrange(
rearrange(model_pt.mixer.Wqkv.weight, "(three o) i -> three o i", three=3)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"three o i -> (three o) i",
)
)
model.mixer.Wqkv.bias.copy_(
rearrange(rearrange(model_pt.mixer.Wqkv.bias, '(three o) -> three o', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim],
'three o -> (three o)')
rearrange(
rearrange(model_pt.mixer.Wqkv.bias, "(three o) -> three o", three=3)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"three o -> (three o)",
)
)
model.mixer.out_proj.weight.copy_(
model_pt.mixer.out_proj.weight[:, rank * partition_dim:(rank + 1) * partition_dim]
model_pt.mixer.out_proj.weight[:, rank * partition_dim : (rank + 1) * partition_dim]
)
if rank == 0:
model.mixer.out_proj.bias.copy_(model_pt.mixer.out_proj.bias)
model.mlp.fc1.weight.copy_(
model_pt.mlp.fc1.weight[rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim]
model_pt.mlp.fc1.weight[rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim]
)
model.mlp.fc1.bias.copy_(
model_pt.mlp.fc1.bias[rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim]
model_pt.mlp.fc1.bias[rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim]
)
model.mlp.fc2.weight.copy_(
model_pt.mlp.fc2.weight[:, rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim]
model_pt.mlp.fc2.weight[
:, rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim
]
)
if rank == 0:
model.mlp.fc2.bias.copy_(model_pt.mlp.fc2.bias)
......@@ -113,83 +152,122 @@ def test_block_parallel(dim, sequence_parallel, world_size, dtype):
model.norm2.weight.copy_(model_pt.norm2.weight)
model.norm2.bias.copy_(model_pt.norm2.bias)
mixer_kwargs = {'seqlen': seqlen}
mixer_kwargs = {"seqlen": seqlen}
out, out_residual = model(x, residual, mixer_kwargs=mixer_kwargs)
out_pt, out_residual_pt = model_pt(rearrange(x_pt, '(b s) d -> b s d', s=seqlen),
rearrange(residual_pt, '(b s) d -> b s d', s=seqlen))
out_pt, out_residual_pt = [rearrange(x, 'b s d -> (b s) d') for x in [out_pt, out_residual_pt]]
out_pt, out_residual_pt = model_pt(
rearrange(x_pt, "(b s) d -> b s d", s=seqlen),
rearrange(residual_pt, "(b s) d -> b s d", s=seqlen),
)
out_pt, out_residual_pt = [rearrange(x, "b s d -> (b s) d") for x in [out_pt, out_residual_pt]]
partition_batch_dim = batch_size * seqlen // world_size
assert torch.allclose(
out,
out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else out_pt,
rtol=rtol, atol=atol
out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel
else out_pt,
rtol=rtol,
atol=atol,
)
assert torch.allclose(
out_residual,
out_residual_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else out_residual_pt,
rtol=rtol, atol=atol
out_residual_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel
else out_residual_pt,
rtol=rtol,
atol=atol,
)
(out_pt + 2 * out_residual_pt).backward(g)
(out + 2 * out_residual).backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else g)
(out + 2 * out_residual).backward(
g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else g
)
allreduce_sequence_parallel_grad(model, parallel_state.get_tensor_model_parallel_group())
parallel_state.destroy_model_parallel()
assert torch.allclose(
x.grad,
x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else x_pt.grad,
rtol=rtol, atol=atol / 10 # magnitude of x.grad is quite small
x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel
else x_pt.grad,
rtol=rtol,
atol=atol / 10, # magnitude of x.grad is quite small
)
assert torch.allclose(
residual.grad,
residual_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else residual_pt.grad,
rtol=rtol, atol=atol
residual_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel
else residual_pt.grad,
rtol=rtol,
atol=atol,
)
# The error for d_weight and d_bias is quite a bit higher
assert torch.allclose(
model.mixer.Wqkv.weight.grad,
rearrange(rearrange(model_pt.mixer.Wqkv.weight.grad, '(three o) i -> three o i', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim],
'three o i -> (three o) i'),
rtol=rtol, atol=atol * 10
rearrange(
rearrange(model_pt.mixer.Wqkv.weight.grad, "(three o) i -> three o i", three=3)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"three o i -> (three o) i",
),
rtol=rtol,
atol=atol * 10,
)
assert torch.allclose(
model.mixer.Wqkv.bias.grad,
rearrange(rearrange(model_pt.mixer.Wqkv.bias.grad, '(three o) -> three o', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim],
'three o -> (three o)'),
rtol=rtol, atol=atol * 5
rearrange(
rearrange(model_pt.mixer.Wqkv.bias.grad, "(three o) -> three o", three=3)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"three o -> (three o)",
),
rtol=rtol,
atol=atol * 5,
)
assert torch.allclose(
model.mixer.out_proj.weight.grad,
model_pt.mixer.out_proj.weight.grad[:, rank * partition_dim:(rank + 1) * partition_dim],
rtol=rtol, atol=atol * 10
model_pt.mixer.out_proj.weight.grad[:, rank * partition_dim : (rank + 1) * partition_dim],
rtol=rtol,
atol=atol * 10,
)
if rank == 0:
assert torch.allclose(model.mixer.out_proj.bias.grad, model_pt.mixer.out_proj.bias.grad, rtol=rtol, atol=atol * 5)
assert torch.allclose(
model.mixer.out_proj.bias.grad,
model_pt.mixer.out_proj.bias.grad,
rtol=rtol,
atol=atol * 5,
)
assert torch.allclose(
model.mlp.fc1.weight.grad,
model_pt.mlp.fc1.weight.grad[rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim],
rtol=rtol, atol=atol * 10
model_pt.mlp.fc1.weight.grad[
rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim
],
rtol=rtol,
atol=atol * 10,
)
assert torch.allclose(
model.mlp.fc1.bias.grad,
model_pt.mlp.fc1.bias.grad[rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim],
rtol=rtol, atol=atol * 5
model_pt.mlp.fc1.bias.grad[rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim],
rtol=rtol,
atol=atol * 5,
)
assert torch.allclose(
model.mlp.fc2.weight.grad,
model_pt.mlp.fc2.weight.grad[:, rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim],
rtol=rtol, atol=atol * 10
model_pt.mlp.fc2.weight.grad[
:, rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim
],
rtol=rtol,
atol=atol * 10,
)
if rank == 0:
assert torch.allclose(model.mlp.fc2.bias.grad, model_pt.mlp.fc2.bias.grad,
rtol=rtol, atol=atol * 5)
assert torch.allclose(
model.mlp.fc2.bias.grad, model_pt.mlp.fc2.bias.grad, rtol=rtol, atol=atol * 5
)
assert torch.allclose(model.norm1.weight.grad, model_pt.norm1.weight.grad, rtol=rtol, atol=atol * 5)
assert torch.allclose(
model.norm1.weight.grad, model_pt.norm1.weight.grad, rtol=rtol, atol=atol * 5
)
assert torch.allclose(model.norm1.bias.grad, model_pt.norm1.bias.grad, rtol=rtol, atol=atol * 5)
assert torch.allclose(model.norm2.weight.grad, model_pt.norm2.weight.grad, rtol=rtol, atol=atol * 5)
assert torch.allclose(
model.norm2.weight.grad, model_pt.norm2.weight.grad, rtol=rtol, atol=atol * 5
)
assert torch.allclose(model.norm2.bias.grad, model_pt.norm2.bias.grad, rtol=rtol, atol=atol * 5)
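# The Wqkv copies above shard the fused QKV projection along its output
# dimension: the (3 * dim) rows are viewed as (three, dim) so every rank
# receives matching slices of Q, K and V. A minimal CPU-only sketch of that
# slicing with toy shapes (`world` and `rank` are illustrative values, no
# distributed setup assumed):
import torch
from einops import rearrange

dim, world, rank = 8, 2, 1
wqkv = torch.randn(3 * dim, dim)  # fused [q; k; v] weight
part = dim // world
shard = rearrange(
    rearrange(wqkv, "(three o) i -> three o i", three=3)[
        :, rank * part : (rank + 1) * part
    ],
    "three o i -> (three o) i",
)
assert shard.shape == (3 * part, dim)  # each rank holds 3 * (dim / world) rows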
# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/modules/test_embedding_parallel.py
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytest
from einops import rearrange
from apex.transformer import parallel_state
from einops import rearrange
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
@pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('sequence_parallel', [True, False])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize('has_pos_emb', [True, False])
@pytest.mark.parametrize("has_pos_emb", [True, False])
# @pytest.mark.parametrize('has_pos_emb', [True])
@pytest.mark.parametrize('dim', [1024])
@pytest.mark.parametrize("dim", [1024])
def test_embedding_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
vocab_size = 50264
seqlen = 2048
......@@ -31,8 +28,8 @@ def test_embedding_parallel(dim, has_pos_emb, sequence_parallel, world_size, dty
assert dim % world_size == 0
rtol, atol = (3e-3, 5e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend='nccl', init_method='env://')
device = f'cuda:{torch.distributed.get_rank()}'
torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
......@@ -44,46 +41,66 @@ def test_embedding_parallel(dim, has_pos_emb, sequence_parallel, world_size, dty
input_ids_pt = torch.randint(0, vocab_size, (batch_size, seqlen), device=device)
input_ids = input_ids_pt.detach().clone()
model_pt = GPT2Embeddings(dim, vocab_size, seqlen if has_pos_emb else 0,
device=device, dtype=dtype)
model = ParallelGPT2Embeddings(dim, vocab_size, seqlen if has_pos_emb else 0,
parallel_state.get_tensor_model_parallel_group(),
sequence_parallel=sequence_parallel, device=device, dtype=dtype)
model_pt = GPT2Embeddings(
dim, vocab_size, seqlen if has_pos_emb else 0, device=device, dtype=dtype
)
model = ParallelGPT2Embeddings(
dim,
vocab_size,
seqlen if has_pos_emb else 0,
parallel_state.get_tensor_model_parallel_group(),
sequence_parallel=sequence_parallel,
device=device,
dtype=dtype,
)
partition_vocab_size = vocab_size // world_size
partition_dim = dim // world_size
with torch.no_grad():
model.word_embeddings.weight.copy_(
model_pt.word_embeddings.weight[rank * partition_vocab_size:(rank + 1) * partition_vocab_size]
model_pt.word_embeddings.weight[
rank * partition_vocab_size : (rank + 1) * partition_vocab_size
]
)
if has_pos_emb:
model.position_embeddings.weight.copy_(
model_pt.position_embeddings.weight[:, rank * partition_dim:(rank + 1) * partition_dim]
model_pt.position_embeddings.weight[
:, rank * partition_dim : (rank + 1) * partition_dim
]
)
out = model(input_ids, combine_batch_seqlen_dim=True)
out_pt = rearrange(model_pt(input_ids), 'b s d -> (b s) d')
out_pt = rearrange(model_pt(input_ids), "b s d -> (b s) d")
partition_batch_dim = batch_size * seqlen // world_size
assert torch.allclose(
out,
out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else out_pt,
rtol=rtol, atol=atol
out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel
else out_pt,
rtol=rtol,
atol=atol,
)
g = torch.randn_like(out_pt)
out_pt.backward(g)
out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else g)
out.backward(
g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else g
)
parallel_state.destroy_model_parallel()
assert torch.allclose(
model.word_embeddings.weight.grad,
model_pt.word_embeddings.weight.grad[rank * partition_vocab_size:(rank + 1) * partition_vocab_size],
rtol=rtol, atol=atol
model_pt.word_embeddings.weight.grad[
rank * partition_vocab_size : (rank + 1) * partition_vocab_size
],
rtol=rtol,
atol=atol,
)
if has_pos_emb:
assert torch.allclose(
model.position_embeddings.weight.grad,
model_pt.position_embeddings.weight.grad[:, rank * partition_dim:(rank + 1) * partition_dim],
rtol=rtol, atol=atol
model_pt.position_embeddings.weight.grad[
:, rank * partition_dim : (rank + 1) * partition_dim
],
rtol=rtol,
atol=atol,
)
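# The embedding test above splits the word-embedding table along the vocab
# dimension and the position-embedding table along the hidden dimension. A
# small sketch of how a rank-local lookup could mask out-of-range token ids
# before an all-reduce (assumed behaviour for illustration only; the actual
# implementation lives in flash_attn.modules.embedding.ParallelGPT2Embeddings):
import torch

vocab_size, dim, world, rank = 16, 4, 2, 0
weight = torch.randn(vocab_size, dim)
part = vocab_size // world
local_weight = weight[rank * part : (rank + 1) * part]  # this rank's vocab shard

token_ids = torch.tensor([1, 9, 3])
in_range = (token_ids >= rank * part) & (token_ids < (rank + 1) * part)
local_ids = (token_ids - rank * part).clamp(0, part - 1)
local_out = local_weight[local_ids] * in_range.unsqueeze(-1).float()
# summing local_out across ranks (all-reduce) would reproduce weight[token_ids]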
......@@ -3,29 +3,25 @@
import math
import pytest
import torch
import torch.nn.functional as F
import pytest
from apex.transformer import parallel_state, tensor_parallel
from einops import rearrange
from apex.transformer import parallel_state
from apex.transformer import tensor_parallel
from flash_attn.modules.mha import MHA, ParallelMHA
is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
@pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('sequence_parallel', [True, False])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize('head_dim', [64, 128])
@pytest.mark.parametrize("head_dim", [64, 128])
# @pytest.mark.parametrize('head_dim', [64])
@pytest.mark.parametrize('embed_dim', [1024, 4096])
@pytest.mark.parametrize("embed_dim", [1024, 4096])
# @pytest.mark.parametrize('embed_dim', [1024])
def test_mha_parallel(embed_dim, head_dim, sequence_parallel, world_size, dtype):
assert embed_dim % head_dim == 0
......@@ -33,8 +29,8 @@ def test_mha_parallel(embed_dim, head_dim, sequence_parallel, world_size, dtype)
assert num_heads % world_size == 0
rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend='nccl', init_method='env://')
device = f'cuda:{torch.distributed.get_rank()}'
torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
......@@ -43,77 +39,122 @@ def test_mha_parallel(embed_dim, head_dim, sequence_parallel, world_size, dtype)
batch_size = 2
seqlen = 1024
assert (batch_size * seqlen) % world_size == 0
x_pt = torch.randn(batch_size * seqlen, embed_dim, device=device, dtype=dtype,
requires_grad=True)
x_pt = torch.randn(
batch_size * seqlen, embed_dim, device=device, dtype=dtype, requires_grad=True
)
# We need to generate g here so that all processes get the same gradient,
# as rank 0 will have an extra bias that changes the RNG.
# If we don't divide by batch_size, the gradient gets a bit too large.
g = torch.randn_like(x_pt) / 32
if sequence_parallel:
x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
x = (
tensor_parallel.scatter_to_sequence_parallel_region(x_pt)
.detach()
.clone()
.requires_grad_()
)
else:
x = x_pt.detach().clone().requires_grad_()
model_pt = MHA(embed_dim, num_heads, rotary_emb_dim=int(head_dim // 2),
use_flash_attn=True, device=device, dtype=dtype)
model_pt = MHA(
embed_dim,
num_heads,
rotary_emb_dim=int(head_dim // 2),
use_flash_attn=True,
device=device,
dtype=dtype,
)
partition_dim = embed_dim // world_size
model = ParallelMHA(embed_dim, num_heads, parallel_state.get_tensor_model_parallel_group(),
rotary_emb_dim=int(head_dim // 2), use_flash_attn=True,
sequence_parallel=sequence_parallel, device=device, dtype=dtype)
model = ParallelMHA(
embed_dim,
num_heads,
parallel_state.get_tensor_model_parallel_group(),
rotary_emb_dim=int(head_dim // 2),
use_flash_attn=True,
sequence_parallel=sequence_parallel,
device=device,
dtype=dtype,
)
with torch.no_grad():
model.Wqkv.weight.copy_(
rearrange(rearrange(model_pt.Wqkv.weight, '(three o) i -> three o i', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim],
'three o i -> (three o) i')
rearrange(
rearrange(model_pt.Wqkv.weight, "(three o) i -> three o i", three=3)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"three o i -> (three o) i",
)
)
model.Wqkv.bias.copy_(
rearrange(rearrange(model_pt.Wqkv.bias, '(three o) -> three o', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim],
'three o -> (three o)')
rearrange(
rearrange(model_pt.Wqkv.bias, "(three o) -> three o", three=3)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"three o -> (three o)",
)
)
model.out_proj.weight.copy_(
model_pt.out_proj.weight[:, rank * partition_dim:(rank + 1) * partition_dim]
model_pt.out_proj.weight[:, rank * partition_dim : (rank + 1) * partition_dim]
)
if rank == 0:
model.out_proj.bias.copy_(model_pt.out_proj.bias)
out = model(x, seqlen=seqlen)
out_pt = rearrange(model_pt(rearrange(x_pt, '(b s) d -> b s d', s=seqlen)), 'b s d -> (b s) d')
out_pt = rearrange(model_pt(rearrange(x_pt, "(b s) d -> b s d", s=seqlen)), "b s d -> (b s) d")
partition_batch_dim = batch_size * seqlen // world_size
assert torch.allclose(
out,
out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else out_pt,
rtol=rtol, atol=atol
out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel
else out_pt,
rtol=rtol,
atol=atol,
)
out_pt.backward(g)
out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else g)
out.backward(
g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else g
)
parallel_state.destroy_model_parallel()
assert torch.allclose(
x.grad,
x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else x_pt.grad,
rtol=rtol, atol=atol / 100 # magnitude of x.grad is quite small
x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel
else x_pt.grad,
rtol=rtol,
atol=atol / 100, # magnitude of x.grad is quite small
)
# The error for d_weight and d_bias is quite a bit higher
assert torch.allclose(
model.Wqkv.weight.grad,
rearrange(rearrange(model_pt.Wqkv.weight.grad, '(three o) i -> three o i', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim],
'three o i -> (three o) i'),
rtol=rtol, atol=atol * 10
rearrange(
rearrange(model_pt.Wqkv.weight.grad, "(three o) i -> three o i", three=3)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"three o i -> (three o) i",
),
rtol=rtol,
atol=atol * 10,
)
assert torch.allclose(
model.Wqkv.bias.grad,
rearrange(rearrange(model_pt.Wqkv.bias.grad, '(three o) -> three o', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim],
'three o -> (three o)'),
rtol=rtol, atol=atol * 5
rearrange(
rearrange(model_pt.Wqkv.bias.grad, "(three o) -> three o", three=3)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"three o -> (three o)",
),
rtol=rtol,
atol=atol * 5,
)
assert torch.allclose(
model.out_proj.weight.grad,
model_pt.out_proj.weight.grad[:, rank * partition_dim:(rank + 1) * partition_dim],
rtol=rtol, atol=atol * 10
model_pt.out_proj.weight.grad[:, rank * partition_dim : (rank + 1) * partition_dim],
rtol=rtol,
atol=atol * 10,
)
if rank == 0:
assert torch.allclose(model.out_proj.bias.grad, model_pt.out_proj.bias.grad, rtol=rtol, atol=atol * 5)
assert torch.allclose(
model.out_proj.bias.grad, model_pt.out_proj.bias.grad, rtol=rtol, atol=atol * 5
)
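# Every sequence-parallel comparison above slices the reference along the
# flattened (batch * seqlen) dimension, because
# scatter_to_sequence_parallel_region hands rank r the r-th contiguous chunk
# of rows. A plain-torch sketch of that split (assumed behaviour, shown
# without apex):
import torch

batch_size, seqlen, dim, world = 2, 8, 4, 4
x = torch.randn(batch_size * seqlen, dim)
chunk = batch_size * seqlen // world
shards = [x[r * chunk : (r + 1) * chunk] for r in range(world)]
assert torch.equal(torch.cat(shards), x)  # the chunks tile the full sequence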
# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/modules/test_mlp_parallel.py
import pytest
import torch
import torch.nn.functional as F
import pytest
from apex.transformer import parallel_state, tensor_parallel
from einops import rearrange
from apex.transformer import parallel_state
from apex.transformer import tensor_parallel
from flash_attn.modules.mlp import GatedMlp, ParallelGatedMlp
is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
@pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('sequence_parallel', [True, False])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize('activation', [F.silu, F.sigmoid])
@pytest.mark.parametrize("activation", [F.silu, F.sigmoid])
# @pytest.mark.parametrize('activation', [F.silu])
@pytest.mark.parametrize('dim', [1024, 4096])
@pytest.mark.parametrize("dim", [1024, 4096])
# @pytest.mark.parametrize('dim', [1024])
def test_mlp_parallel(dim, activation, sequence_parallel, world_size, dtype):
rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend='nccl', init_method='env://')
device = f'cuda:{torch.distributed.get_rank()}'
torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
......@@ -39,34 +35,51 @@ def test_mlp_parallel(dim, activation, sequence_parallel, world_size, dtype):
batch_size = 2
seqlen = 1024
assert (batch_size * seqlen) % world_size == 0
x_pt = torch.randn(batch_size * seqlen, dim, device=device, dtype=dtype,
requires_grad=True)
x_pt = torch.randn(batch_size * seqlen, dim, device=device, dtype=dtype, requires_grad=True)
# We need to generate g here so that all processes get the same gradient,
# as rank 0 will have an extra bias that changes the RNG.
# If we don't divide by batch_size, the gradient gets a bit too large.
g = torch.randn_like(x_pt) / 32
if sequence_parallel:
x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
x = (
tensor_parallel.scatter_to_sequence_parallel_region(x_pt)
.detach()
.clone()
.requires_grad_()
)
else:
x = x_pt.detach().clone().requires_grad_()
model_pt = GatedMlp(dim, activation=activation, device=device, dtype=dtype)
partition_dim = model_pt.fc1.weight.shape[0] // 2 // world_size
model = ParallelGatedMlp(dim, parallel_state.get_tensor_model_parallel_group(),
activation=activation,
sequence_parallel=sequence_parallel, device=device, dtype=dtype)
model = ParallelGatedMlp(
dim,
parallel_state.get_tensor_model_parallel_group(),
activation=activation,
sequence_parallel=sequence_parallel,
device=device,
dtype=dtype,
)
with torch.no_grad():
model.fc1.weight.copy_(
rearrange(rearrange(model_pt.fc1.weight, '(two o) i -> two o i', two=2)[:, rank * partition_dim:(rank + 1) * partition_dim],
'two o i -> (two o) i')
rearrange(
rearrange(model_pt.fc1.weight, "(two o) i -> two o i", two=2)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"two o i -> (two o) i",
)
)
model.fc1.bias.copy_(
rearrange(rearrange(model_pt.fc1.bias, '(two o) -> two o', two=2)[:, rank * partition_dim:(rank + 1) * partition_dim],
'two o -> (two o)')
rearrange(
rearrange(model_pt.fc1.bias, "(two o) -> two o", two=2)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"two o -> (two o)",
)
)
model.fc2.weight.copy_(
model_pt.fc2.weight[:, rank * partition_dim:(rank + 1) * partition_dim]
model_pt.fc2.weight[:, rank * partition_dim : (rank + 1) * partition_dim]
)
if rank == 0:
model.fc2.bias.copy_(model_pt.fc2.bias)
......@@ -76,39 +89,55 @@ def test_mlp_parallel(dim, activation, sequence_parallel, world_size, dtype):
partition_batch_dim = batch_size * seqlen // world_size
assert torch.allclose(
out,
out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else out_pt,
rtol=rtol, atol=atol
out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel
else out_pt,
rtol=rtol,
atol=atol,
)
out_pt.backward(g)
out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else g)
out.backward(
g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else g
)
parallel_state.destroy_model_parallel()
assert torch.allclose(
x.grad,
x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else x_pt.grad,
rtol=rtol, atol=atol
x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel
else x_pt.grad,
rtol=rtol,
atol=atol,
)
assert torch.allclose(
model.fc1.weight.grad,
rearrange(rearrange(model_pt.fc1.weight.grad, '(two o) i -> two o i', two=2)[:, rank * partition_dim:(rank + 1) * partition_dim],
'two o i -> (two o) i'),
rtol=rtol, atol=atol
rearrange(
rearrange(model_pt.fc1.weight.grad, "(two o) i -> two o i", two=2)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"two o i -> (two o) i",
),
rtol=rtol,
atol=atol,
)
assert torch.allclose(
model.fc1.bias.grad,
rearrange(rearrange(model_pt.fc1.bias.grad, '(two o) -> two o', two=2)[:, rank * partition_dim:(rank + 1) * partition_dim],
'two o -> (two o)'),
rtol=rtol, atol=atol
rearrange(
rearrange(model_pt.fc1.bias.grad, "(two o) -> two o", two=2)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"two o -> (two o)",
),
rtol=rtol,
atol=atol,
)
assert torch.allclose(
model.fc2.weight.grad,
model_pt.fc2.weight.grad[:, rank * partition_dim:(rank + 1) * partition_dim],
rtol=rtol, atol=atol
model_pt.fc2.weight.grad[:, rank * partition_dim : (rank + 1) * partition_dim],
rtol=rtol,
atol=atol,
)
if rank == 0:
assert torch.allclose(model.fc2.bias.grad, model_pt.fc2.bias.grad, rtol=rtol, atol=atol)
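# GatedMlp (tested above) packs two projections into fc1, gates one half with
# the activation of the other, then applies fc2; that is why fc1 is sharded
# with the "(two o)" pattern. A minimal functional sketch of the reference
# computation (which half acts as the gate is an assumption here, not taken
# from the module's source):
import torch
import torch.nn.functional as F

dim, hidden = 8, 16
x = torch.randn(3, dim)
fc1_weight = torch.randn(2 * hidden, dim)  # packs [value; gate] projections
fc2_weight = torch.randn(dim, hidden)
y, gate = (x @ fc1_weight.t()).chunk(2, dim=-1)
out = (y * F.silu(gate)) @ fc2_weight.t()  # shape (3, dim)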
import math
import pytest
import torch
import torch.nn.functional as F
import pytest
from einops import rearrange, repeat
from flash_attn.ops.layer_norm import DropoutAddLayerNorm, dropout_add_layer_norm
from flash_attn.ops.layer_norm import dropout_add_layer_norm_subset
from flash_attn.ops.rms_norm import DropoutAddRMSNorm, dropout_add_rms_norm
from flash_attn.ops.rms_norm import dropout_add_rms_norm_subset
from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual
from flash_attn.ops.layer_norm import (
DropoutAddLayerNorm,
dropout_add_layer_norm,
dropout_add_layer_norm_parallel_residual,
dropout_add_layer_norm_subset,
)
from flash_attn.ops.rms_norm import (
DropoutAddRMSNorm,
dropout_add_rms_norm,
dropout_add_rms_norm_parallel_residual,
dropout_add_rms_norm_subset,
)
try:
from apex.normalization import FusedRMSNorm
......@@ -20,28 +24,42 @@ except:
FusedRMSNorm, fused_rms_norm_affine = None, None
is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
@pytest.mark.parametrize('is_rms_norm', [False, True])
@pytest.mark.parametrize('has_colscale', [True, False])
@pytest.mark.parametrize("is_rms_norm", [False, True])
@pytest.mark.parametrize("has_colscale", [True, False])
# @pytest.mark.parametrize('has_colscale', [False])
@pytest.mark.parametrize('has_rowscale', [True, False])
@pytest.mark.parametrize("has_rowscale", [True, False])
# @pytest.mark.parametrize('has_rowscale', [True])
@pytest.mark.parametrize('has_residual', [True, False])
@pytest.mark.parametrize("has_residual", [True, False])
# @pytest.mark.parametrize('has_residual', [False])
@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
@pytest.mark.parametrize("dropout_p", [0.37, 0.0])
# @pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
@pytest.mark.parametrize('input_dtype,residual_dtype',
[(torch.float16, torch.float16), (torch.float16, torch.float32),
(torch.float32, torch.float32)]
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
@pytest.mark.parametrize(
"input_dtype,residual_dtype",
[(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
@pytest.mark.parametrize(
"hidden_size",
[192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],
)
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
dropout_p, has_residual, has_rowscale, has_colscale, is_rms_norm):
def test_dropout_layer_norm_training(
hidden_size,
input_dtype,
residual_dtype,
weight_dtype,
dropout_p,
has_residual,
has_rowscale,
has_colscale,
is_rms_norm,
):
if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
pytest.skip() # Not supported
if is_rms_norm and FusedRMSNorm is None:
......@@ -49,15 +67,16 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
layer_norm_cls = torch.nn.LayerNorm if not is_rms_norm else FusedRMSNorm
our_layer_norm_cls = DropoutAddLayerNorm if not is_rms_norm else DropoutAddRMSNorm
our_layer_norm_func = dropout_add_layer_norm if not is_rms_norm else dropout_add_rms_norm
device = 'cuda'
device = "cuda"
# rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
rtol, atol = (1e-3, 1e-4)
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen = 512
x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
requires_grad=True)
x0_pt = torch.randn(
batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
)
x0 = x0_pt.detach().clone().requires_grad_()
x0_ref = x0_pt.detach().clone().float().requires_grad_()
if has_colscale:
......@@ -76,8 +95,8 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
rowscale = torch.empty(batch_size, seqlen, device=device, dtype=input_dtype)
survival_rate = 0.87
rowscale = rowscale.bernoulli_(survival_rate) / survival_rate
x0_scaled_pt = x0_pt * rearrange(rowscale, '... -> ... 1')
x0_scaled_ref = x0_ref * rearrange(rowscale, '... -> ... 1')
x0_scaled_pt = x0_pt * rearrange(rowscale, "... -> ... 1")
x0_scaled_ref = x0_ref * rearrange(rowscale, "... -> ... 1")
else:
rowscale = None
x0_scaled_pt = x0_pt
......@@ -98,16 +117,29 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
model.bias.copy_(model_pt.bias)
model_ref.bias.copy_(model_pt.bias)
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
out, dmask = our_layer_norm_func(x0, res, model.weight, model.bias, model.p,
model.eps, rowscale=rowscale, layerscale=colscale,
residual_in_fp32=residual_in_fp32, return_dropout_mask=True)
out, dmask = our_layer_norm_func(
x0,
res,
model.weight,
model.bias,
model.p,
model.eps,
rowscale=rowscale,
layerscale=colscale,
residual_in_fp32=residual_in_fp32,
return_dropout_mask=True,
)
assert out.dtype == input_dtype
print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
print(f"Actual dropout fraction: {1 - dmask.float().mean().item()}")
if has_residual:
residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + res_pt.float()).to(dtype=residual_dtype)
residual_pt = (
(x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + res_pt.float()
).to(dtype=residual_dtype)
residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + res_ref
else:
residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(
dtype=residual_dtype
)
residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p)
out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)
out_ref = model_ref(residual_ref)
......@@ -119,24 +151,33 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
out_ref.backward(g)
assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
if has_residual:
assert (res.grad - res_ref.grad).abs().max() <= 4 * (res_pt.grad - res_ref.grad).abs().max() + 1e-4
assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 3 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 3e-5
assert (res.grad - res_ref.grad).abs().max() <= 4 * (
res_pt.grad - res_ref.grad
).abs().max() + 1e-4
assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 3 * (
model_pt.weight.grad - model_ref.weight.grad
).abs().max() + 3e-5
if not is_rms_norm:
assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 3e-5
assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (
model_pt.bias.grad - model_ref.bias.grad
).abs().max() + 3e-5
if has_colscale:
assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4
assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (
colscale_pt.grad - colscale_ref.grad
).abs().max() + 2e-4
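# The fused kernel under test is compared against a composition of plain
# PyTorch ops: scale x0 by colscale/rowscale, apply dropout, add the residual,
# then LayerNorm (or RMSNorm). A compact reference sketch of the non-subset
# path; dtype handling (fp32 residual, weight_dtype casts) is omitted and the
# helper name is illustrative:
import torch
import torch.nn.functional as F

def dropout_add_layer_norm_ref(x0, residual, weight, bias, p, eps,
                               rowscale=None, colscale=None):
    if colscale is not None:
        x0 = x0 * colscale
    if rowscale is not None:
        x0 = x0 * rowscale.unsqueeze(-1)
    x0 = F.dropout(x0, p=p, training=True)  # already divides by (1 - p)
    pre = x0 + residual if residual is not None else x0
    return F.layer_norm(pre, (pre.shape[-1],), weight, bias, eps)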
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
[(torch.float16, torch.float16), (torch.float16, torch.float32),
(torch.float32, torch.float32)]
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
@pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize(
"input_dtype,residual_dtype",
[(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
@pytest.mark.parametrize("hidden_size", [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weight_dtype):
if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
pytest.skip() # Not supported
device = 'cuda'
device = "cuda"
# rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
rtol, atol = (1e-3, 1e-4)
dropout_p = 0.37
......@@ -144,8 +185,9 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weigh
torch.random.manual_seed(0)
batch_size = 32
seqlen = 512
x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
requires_grad=True)
x0_pt = torch.randn(
batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
)
x0 = x0_pt.detach().clone().requires_grad_()
x0_ref = x0_pt.detach().clone().float().requires_grad_()
res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
......@@ -172,27 +214,39 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weigh
assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
@pytest.mark.parametrize('is_rms_norm', [False, True])
@pytest.mark.parametrize('has_colscale', [True, False])
@pytest.mark.parametrize('has_rowscale', [True, False])
@pytest.mark.parametrize('has_residual', [True, False])
@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
[(torch.float16, torch.float16), (torch.float16, torch.float32),
(torch.float32, torch.float32)]
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
@pytest.mark.parametrize("is_rms_norm", [False, True])
@pytest.mark.parametrize("has_colscale", [True, False])
@pytest.mark.parametrize("has_rowscale", [True, False])
@pytest.mark.parametrize("has_residual", [True, False])
@pytest.mark.parametrize("dropout_p", [0.37, 0.0])
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize(
"input_dtype,residual_dtype",
[(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize('has_colscale', [True])
# @pytest.mark.parametrize('has_rowscale', [False])
# @pytest.mark.parametrize('has_residual', [True])
# @pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
@pytest.mark.parametrize(
"hidden_size",
[192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],
)
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
dropout_p, has_residual, has_rowscale, has_colscale,
is_rms_norm):
def test_dropout_layer_norm_prenorm_training(
hidden_size,
input_dtype,
residual_dtype,
weight_dtype,
dropout_p,
has_residual,
has_rowscale,
has_colscale,
is_rms_norm,
):
if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
pytest.skip() # Not supported
if is_rms_norm and FusedRMSNorm is None:
......@@ -200,15 +254,16 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
layer_norm_cls = torch.nn.LayerNorm if not is_rms_norm else FusedRMSNorm
our_layer_norm_cls = DropoutAddLayerNorm if not is_rms_norm else DropoutAddRMSNorm
our_layer_norm_func = dropout_add_layer_norm if not is_rms_norm else dropout_add_rms_norm
device = 'cuda'
device = "cuda"
# rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
rtol, atol = (1e-3, 2e-4)
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen = 512
x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
requires_grad=True)
x0_pt = torch.randn(
batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
)
x0 = x0_pt.detach().clone().requires_grad_()
x0_ref = x0_pt.detach().clone().float().requires_grad_()
if has_colscale:
......@@ -227,8 +282,8 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
rowscale = torch.empty(batch_size, seqlen, device=device, dtype=input_dtype)
survival_rate = 0.87
rowscale = rowscale.bernoulli_(survival_rate) / survival_rate
x0_scaled_pt = x0_pt * rearrange(rowscale, '... -> ... 1')
x0_scaled_ref = x0_ref * rearrange(rowscale, '... -> ... 1')
x0_scaled_pt = x0_pt * rearrange(rowscale, "... -> ... 1")
x0_scaled_ref = x0_ref * rearrange(rowscale, "... -> ... 1")
else:
rowscale = None
x0_scaled_pt = x0_pt
......@@ -241,8 +296,9 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
if not is_rms_norm:
torch.nn.init.normal_(model_pt.bias)
model_ref = layer_norm_cls(hidden_size).to(device=device, dtype=torch.float32)
model = our_layer_norm_cls(hidden_size, prenorm=True, p=dropout_p, device=device,
dtype=weight_dtype)
model = our_layer_norm_cls(
hidden_size, prenorm=True, p=dropout_p, device=device, dtype=weight_dtype
)
with torch.no_grad():
model.weight.copy_(model_pt.weight)
model_ref.weight.copy_(model_pt.weight)
......@@ -250,24 +306,38 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
model.bias.copy_(model_pt.bias)
model_ref.bias.copy_(model_pt.bias)
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
out, residual, dmask = our_layer_norm_func(x0, res, model.weight, model.bias, model.p,
model.eps, rowscale=rowscale,
layerscale=colscale, prenorm=True,
residual_in_fp32=residual_in_fp32,
return_dropout_mask=True)
print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
out, residual, dmask = our_layer_norm_func(
x0,
res,
model.weight,
model.bias,
model.p,
model.eps,
rowscale=rowscale,
layerscale=colscale,
prenorm=True,
residual_in_fp32=residual_in_fp32,
return_dropout_mask=True,
)
print(f"Actual dropout fraction: {1 - dmask.float().mean().item()}")
if has_residual:
residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + res_pt.float()).to(dtype=residual_dtype)
residual_pt = (
(x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + res_pt.float()
).to(dtype=residual_dtype)
residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + res_ref
else:
residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(
dtype=residual_dtype
)
residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p)
out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)
out_ref = model_ref(residual_ref)
assert out.dtype == input_dtype
assert residual.dtype == residual_dtype
assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
assert (residual - residual_ref).abs().max() <= 4 * (residual_pt - residual_ref).abs().max() + 1e-4
assert (residual - residual_ref).abs().max() <= 4 * (
residual_pt - residual_ref
).abs().max() + 1e-4
g = torch.randn_like(out) / batch_size
(out_pt * F.sigmoid(residual_pt)).backward(g)
......@@ -275,24 +345,33 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
(out_ref * F.sigmoid(residual_ref.to(dtype=residual_dtype))).backward(g)
assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
if has_residual:
assert (res.grad - res_ref.grad).abs().max() <= 4 * (res_pt.grad - res_ref.grad).abs().max() + 1e-4
assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4
assert (res.grad - res_ref.grad).abs().max() <= 4 * (
res_pt.grad - res_ref.grad
).abs().max() + 1e-4
assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (
model_pt.weight.grad - model_ref.weight.grad
).abs().max() + 2e-4
if not is_rms_norm:
assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4
assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (
model_pt.bias.grad - model_ref.bias.grad
).abs().max() + 2e-4
if has_colscale:
assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4
assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (
colscale_pt.grad - colscale_ref.grad
).abs().max() + 2e-4
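# In the prenorm tests the fused op also returns the pre-LayerNorm residual
# stream, which the backward pass above exercises via
# (out * sigmoid(residual)).backward(g). A sketch of the prenorm reference,
# reusing the same composition (illustrative helper, scales and dtype casts
# omitted):
import torch
import torch.nn.functional as F

def dropout_add_layer_norm_prenorm_ref(x0, residual, weight, bias, p, eps):
    pre = F.dropout(x0, p=p, training=True) + (residual if residual is not None else 0)
    out = F.layer_norm(pre, (pre.shape[-1],), weight, bias, eps)
    return out, pre  # normalized output plus the un-normalized residual stream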
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
[(torch.float16, torch.float16), (torch.float16, torch.float32),
(torch.float32, torch.float32)]
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
@pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize(
"input_dtype,residual_dtype",
[(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
@pytest.mark.parametrize("hidden_size", [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
def test_dropout_layer_norm_prenorm_eval(hidden_size, input_dtype, residual_dtype, weight_dtype):
if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
pytest.skip() # Not supported
device = 'cuda'
device = "cuda"
# rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
rtol, atol = (1e-3, 1e-4)
dropout_p = 0.37
......@@ -300,8 +379,9 @@ def test_dropout_layer_norm_prenorm_eval(hidden_size, input_dtype, residual_dtyp
torch.random.manual_seed(0)
batch_size = 32
seqlen = 512
x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
requires_grad=True)
x0_pt = torch.randn(
batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
)
x0 = x0_pt.detach().clone().requires_grad_()
x0_ref = x0_pt.detach().clone().float().requires_grad_()
res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
......@@ -310,8 +390,9 @@ def test_dropout_layer_norm_prenorm_eval(hidden_size, input_dtype, residual_dtyp
model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
torch.nn.init.normal_(model_pt.weight)
torch.nn.init.normal_(model_pt.bias)
model = DropoutAddLayerNorm(hidden_size, prenorm=True, p=dropout_p, device=device,
dtype=weight_dtype)
model = DropoutAddLayerNorm(
hidden_size, prenorm=True, p=dropout_p, device=device, dtype=weight_dtype
)
model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
with torch.no_grad():
model.weight.copy_(model_pt.weight)
......@@ -327,30 +408,36 @@ def test_dropout_layer_norm_prenorm_eval(hidden_size, input_dtype, residual_dtyp
out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(input_dtype)
out_ref = model_ref(residual_ref)
assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
assert (residual - residual_ref).abs().max() <= 4 * (residual_pt - residual_ref).abs().max() + 1e-4
assert (residual - residual_ref).abs().max() <= 4 * (
residual_pt - residual_ref
).abs().max() + 1e-4
@pytest.mark.parametrize('has_colscale', [True, False])
@pytest.mark.parametrize('has_residual', [True, False])
@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
[(torch.float16, torch.float16), (torch.float16, torch.float32),
(torch.float32, torch.float32)]
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
@pytest.mark.parametrize("has_colscale", [True, False])
@pytest.mark.parametrize("has_residual", [True, False])
@pytest.mark.parametrize("dropout_p", [0.37, 0.0])
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize(
"input_dtype,residual_dtype",
[(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize('has_colscale', [True])
# @pytest.mark.parametrize('has_residual', [True])
# @pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
@pytest.mark.parametrize(
"hidden_size",
[192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],
)
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_subset_training(
hidden_size, input_dtype, residual_dtype, weight_dtype, dropout_p,
has_residual, has_colscale):
hidden_size, input_dtype, residual_dtype, weight_dtype, dropout_p, has_residual, has_colscale
):
if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
pytest.skip() # Not supported
device = 'cuda'
device = "cuda"
# rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
rtol, atol = (1e-3, 2e-4)
# set seed
......@@ -359,23 +446,28 @@ def test_dropout_layer_norm_subset_training(
seqlen = 512
drop_path_rate = 0.4
drop_path_scale = 1 / (1 - drop_path_rate)
def generate_droppath_masks(batch_size, seqlen, drop_path_rate, device):
# Do it on CPU so we can get the numrows (with .item()) without GPU-CPU sync
mask_batch = torch.rand(batch_size) < 1 - drop_path_rate
numrows = (mask_batch).sum().item() * seqlen
mask_batch = mask_batch.to(device=device, non_blocking=True)
mask_batch_seqlen = repeat(mask_batch, 'b -> (b s)', s=seqlen)
subset = torch.cumsum(mask_batch_seqlen, dim=0,
dtype=torch.int32).masked_fill_(~mask_batch_seqlen, 0)
return mask_batch, numrows, rearrange(subset, '(b s) -> b s', b=batch_size)
x0_mask_batch, x0_numrows, x0_subset = generate_droppath_masks(batch_size, seqlen,
drop_path_rate, device)
out_mask_batch, out_numrows, out_subset = generate_droppath_masks(batch_size, seqlen,
drop_path_rate, device)
x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
requires_grad=True)
mask_batch_seqlen = repeat(mask_batch, "b -> (b s)", s=seqlen)
subset = torch.cumsum(mask_batch_seqlen, dim=0, dtype=torch.int32).masked_fill_(
~mask_batch_seqlen, 0
)
return mask_batch, numrows, rearrange(subset, "(b s) -> b s", b=batch_size)
x0_mask_batch, x0_numrows, x0_subset = generate_droppath_masks(
batch_size, seqlen, drop_path_rate, device
)
out_mask_batch, out_numrows, out_subset = generate_droppath_masks(
batch_size, seqlen, drop_path_rate, device
)
x0_pt = torch.randn(
batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
)
x0 = x0_pt.detach().clone()[x0_mask_batch].requires_grad_()
x0_ref = x0_pt.detach().clone().float().requires_grad_()
if has_colscale:
......@@ -402,8 +494,9 @@ def test_dropout_layer_norm_subset_training(
torch.nn.init.normal_(model_pt.weight)
torch.nn.init.normal_(model_pt.bias)
model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
model = DropoutAddLayerNorm(hidden_size, prenorm=False, p=dropout_p, device=device,
dtype=weight_dtype)
model = DropoutAddLayerNorm(
hidden_size, prenorm=False, p=dropout_p, device=device, dtype=weight_dtype
)
with torch.no_grad():
model.weight.copy_(model_pt.weight)
model.bias.copy_(model_pt.bias)
......@@ -412,25 +505,42 @@ def test_dropout_layer_norm_subset_training(
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
out, dmask = dropout_add_layer_norm_subset(
x0, res, model.weight, model.bias, model.p, model.eps, layerscale=colscale,
x0_subset=x0_subset, out_subset=out_subset, rowscale_const=drop_path_scale,
out_numrows = out_numrows, prenorm=False, residual_in_fp32=residual_in_fp32,
return_dropout_mask=True)
print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
x0_scaled_pt = x0_scaled_pt.masked_fill(
repeat(~x0_mask_batch, 'b -> b s d', s=seqlen, d=hidden_size), 0
) * drop_path_scale
x0_scaled_ref = x0_scaled_ref.masked_fill(
repeat(~x0_mask_batch, 'b -> b s d', s=seqlen, d=hidden_size), 0
) * drop_path_scale
x0,
res,
model.weight,
model.bias,
model.p,
model.eps,
layerscale=colscale,
x0_subset=x0_subset,
out_subset=out_subset,
rowscale_const=drop_path_scale,
out_numrows=out_numrows,
prenorm=False,
residual_in_fp32=residual_in_fp32,
return_dropout_mask=True,
)
print(f"Actual dropout fraction: {1 - dmask.float().mean().item()}")
x0_scaled_pt = (
x0_scaled_pt.masked_fill(repeat(~x0_mask_batch, "b -> b s d", s=seqlen, d=hidden_size), 0)
* drop_path_scale
)
x0_scaled_ref = (
x0_scaled_ref.masked_fill(repeat(~x0_mask_batch, "b -> b s d", s=seqlen, d=hidden_size), 0)
* drop_path_scale
)
dmask_expanded = torch.zeros_like(x0_pt, dtype=torch.uint8)
dmask_expanded[x0_mask_batch] = dmask
if has_residual:
residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p) + res_pt.float()).to(dtype=residual_dtype)
residual_pt = (
(x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p) + res_pt.float()
).to(dtype=residual_dtype)
residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + res_ref
else:
residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p)).to(
dtype=residual_dtype
)
residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p)
out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)[out_mask_batch]
out_ref = model_ref(residual_ref)[out_mask_batch]
......@@ -441,36 +551,50 @@ def test_dropout_layer_norm_subset_training(
out_pt.backward(g)
out.backward(g)
out_ref.backward(g)
assert (x0.grad - x0_ref.grad[x0_mask_batch]).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad)[x0_mask_batch].abs().max() + 1e-4
assert (x0.grad - x0_ref.grad[x0_mask_batch]).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad)[
x0_mask_batch
].abs().max() + 1e-4
if has_residual:
assert (res.grad - res_ref.grad).abs().max() <= 4 * (res_pt.grad - res_ref.grad).abs().max() + 1e-4
assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4
assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4
assert (res.grad - res_ref.grad).abs().max() <= 4 * (
res_pt.grad - res_ref.grad
).abs().max() + 1e-4
assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (
model_pt.weight.grad - model_ref.weight.grad
).abs().max() + 2e-4
assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (
model_pt.bias.grad - model_ref.bias.grad
).abs().max() + 2e-4
if has_colscale:
assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4
assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (
colscale_pt.grad - colscale_ref.grad
).abs().max() + 2e-4
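# generate_droppath_masks (above) keeps whole sequences with probability
# (1 - drop_path_rate) and encodes the surviving rows as 1-based cumulative
# indices, with 0 marking dropped rows; that is what x0_subset / out_subset
# feed to the kernel. A tiny concrete example of the encoding:
import torch
from einops import rearrange, repeat

batch_size, seqlen = 4, 2
mask_batch = torch.tensor([True, False, True, True])
mask_rows = repeat(mask_batch, "b -> (b s)", s=seqlen)
subset = torch.cumsum(mask_rows, dim=0, dtype=torch.int32).masked_fill_(~mask_rows, 0)
print(rearrange(subset, "(b s) -> b s", b=batch_size))
# tensor([[1, 2], [0, 0], [3, 4], [5, 6]], dtype=torch.int32)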
@pytest.mark.parametrize('has_colscale', [True, False])
@pytest.mark.parametrize('has_residual', [True, False])
@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
[(torch.float16, torch.float16), (torch.float16, torch.float32),
(torch.float32, torch.float32)]
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
@pytest.mark.parametrize("has_colscale", [True, False])
@pytest.mark.parametrize("has_residual", [True, False])
@pytest.mark.parametrize("dropout_p", [0.37, 0.0])
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize(
"input_dtype,residual_dtype",
[(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize('has_colscale', [True])
# @pytest.mark.parametrize('has_residual', [True])
# @pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
@pytest.mark.parametrize(
"hidden_size",
[192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],
)
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_subset_prenorm_training(
hidden_size, input_dtype, residual_dtype, weight_dtype, dropout_p,
has_residual, has_colscale):
hidden_size, input_dtype, residual_dtype, weight_dtype, dropout_p, has_residual, has_colscale
):
if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
pytest.skip() # Not supported
device = 'cuda'
device = "cuda"
# rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
rtol, atol = (1e-3, 2e-4)
# set seed
@@ -479,23 +603,28 @@ def test_dropout_layer_norm_subset_prenorm_training(
seqlen = 512
drop_path_rate = 0.4
drop_path_scale = 1 / (1 - drop_path_rate)
def generate_droppath_masks(batch_size, seqlen, drop_path_rate, device):
# Do it on CPU so we can get the numrows (with .item()) without GPU-CPU sync
mask_batch = torch.rand(batch_size) < 1 - drop_path_rate
numrows = (mask_batch).sum().item() * seqlen
mask_batch = mask_batch.to(device=device, non_blocking=True)
mask_batch_seqlen = repeat(mask_batch, 'b -> (b s)', s=seqlen)
subset = torch.cumsum(mask_batch_seqlen, dim=0,
dtype=torch.int32).masked_fill_(~mask_batch_seqlen, 0)
return mask_batch, numrows, rearrange(subset, '(b s) -> b s', b=batch_size)
x0_mask_batch, x0_numrows, x0_subset = generate_droppath_masks(batch_size, seqlen,
drop_path_rate, device)
out_mask_batch, out_numrows, out_subset = generate_droppath_masks(batch_size, seqlen,
drop_path_rate, device)
x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
requires_grad=True)
mask_batch_seqlen = repeat(mask_batch, "b -> (b s)", s=seqlen)
subset = torch.cumsum(mask_batch_seqlen, dim=0, dtype=torch.int32).masked_fill_(
~mask_batch_seqlen, 0
)
return mask_batch, numrows, rearrange(subset, "(b s) -> b s", b=batch_size)
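For reference, the `subset` tensor built by `generate_droppath_masks` is a 1-based position index over the flattened `(batch * seqlen)` grid, with 0 marking rows removed by drop path. A self-contained toy illustration (the sizes and mask values are made up):

```python
import torch
from einops import rearrange, repeat

# Toy check of the subset-index construction: kept batch rows get consecutive
# 1-based indices over the flattened (batch * seqlen) layout, dropped rows get 0.
batch_size, seqlen = 4, 3
mask_batch = torch.tensor([True, False, True, True])
mask_batch_seqlen = repeat(mask_batch, "b -> (b s)", s=seqlen)
subset = torch.cumsum(mask_batch_seqlen, dim=0, dtype=torch.int32).masked_fill_(
    ~mask_batch_seqlen, 0
)
subset = rearrange(subset, "(b s) -> b s", b=batch_size)
print(subset)
# tensor([[1, 2, 3],
#         [0, 0, 0],
#         [4, 5, 6],
#         [7, 8, 9]], dtype=torch.int32)
```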
x0_mask_batch, x0_numrows, x0_subset = generate_droppath_masks(
batch_size, seqlen, drop_path_rate, device
)
out_mask_batch, out_numrows, out_subset = generate_droppath_masks(
batch_size, seqlen, drop_path_rate, device
)
x0_pt = torch.randn(
batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
)
x0 = x0_pt.detach().clone()[x0_mask_batch].requires_grad_()
x0_ref = x0_pt.detach().clone().float().requires_grad_()
if has_colscale:
@@ -522,8 +651,9 @@ def test_dropout_layer_norm_subset_prenorm_training(
torch.nn.init.normal_(model_pt.weight)
torch.nn.init.normal_(model_pt.bias)
model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
model = DropoutAddLayerNorm(hidden_size, prenorm=True, p=dropout_p, device=device,
dtype=weight_dtype)
model = DropoutAddLayerNorm(
hidden_size, prenorm=True, p=dropout_p, device=device, dtype=weight_dtype
)
with torch.no_grad():
model.weight.copy_(model_pt.weight)
model.bias.copy_(model_pt.bias)
@@ -532,89 +662,139 @@ def test_dropout_layer_norm_subset_prenorm_training(
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
out, residual, dmask = dropout_add_layer_norm_subset(
x0, res, model.weight, model.bias, model.p, model.eps, layerscale=colscale,
x0_subset=x0_subset, out_subset=out_subset, rowscale_const=drop_path_scale,
out_numrows = out_numrows, prenorm=True, residual_in_fp32=residual_in_fp32,
return_dropout_mask=True)
print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
x0_scaled_pt = x0_scaled_pt.masked_fill(
repeat(~x0_mask_batch, 'b -> b s d', s=seqlen, d=hidden_size), 0
) * drop_path_scale
x0_scaled_ref = x0_scaled_ref.masked_fill(
repeat(~x0_mask_batch, 'b -> b s d', s=seqlen, d=hidden_size), 0
) * drop_path_scale
x0,
res,
model.weight,
model.bias,
model.p,
model.eps,
layerscale=colscale,
x0_subset=x0_subset,
out_subset=out_subset,
rowscale_const=drop_path_scale,
out_numrows=out_numrows,
prenorm=True,
residual_in_fp32=residual_in_fp32,
return_dropout_mask=True,
)
print(f"Actual dropout fraction: {1 - dmask.float().mean().item()}")
x0_scaled_pt = (
x0_scaled_pt.masked_fill(repeat(~x0_mask_batch, "b -> b s d", s=seqlen, d=hidden_size), 0)
* drop_path_scale
)
x0_scaled_ref = (
x0_scaled_ref.masked_fill(repeat(~x0_mask_batch, "b -> b s d", s=seqlen, d=hidden_size), 0)
* drop_path_scale
)
dmask_expanded = torch.zeros_like(x0_pt, dtype=torch.uint8)
dmask_expanded[x0_mask_batch] = dmask
if has_residual:
residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p) + res_pt.float()).to(dtype=residual_dtype)
residual_pt = (
(x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p) + res_pt.float()
).to(dtype=residual_dtype)
residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + res_ref
else:
residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p)).to(
dtype=residual_dtype
)
residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p)
out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)[out_mask_batch]
out_ref = model_ref(residual_ref)[out_mask_batch]
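Condensed into one place, the reference path that `residual_pt`/`out_pt` and `residual_ref`/`out_ref` implement is: zero out drop-path'd batch rows, scale by `drop_path_scale`, apply the saved dropout mask with inverted scaling, add the residual, LayerNorm, and keep only the output-subset rows. The sketch below mirrors that computation in plain PyTorch; the function name is ours and it takes an already-expanded dropout mask (the test instead scatters the compact `dmask` into `dmask_expanded`), so treat it as an illustration rather than the library API:

```python
import torch
import torch.nn.functional as F
from einops import repeat


def subset_prenorm_reference(x0, res, dmask, weight, bias, x0_mask_batch, out_mask_batch,
                             dropout_p, drop_path_scale, eps=1e-5):
    # Zero out drop-path'd batch rows and apply the drop-path scale.
    b, s, d = x0.shape
    keep = repeat(x0_mask_batch, "b -> b s d", s=s, d=d)
    x0_scaled = x0.float().masked_fill(~keep, 0) * drop_path_scale
    # Saved dropout mask with inverted scaling, plus optional residual, in fp32.
    residual = (x0_scaled * dmask.float()) / (1 - dropout_p)
    if res is not None:
        residual = residual + res.float()
    # LayerNorm, then keep only the output-subset batch rows.
    out = F.layer_norm(residual, (d,), weight.float(), bias.float(), eps=eps)
    return out[out_mask_batch], residual


# Illustrative call with random CPU tensors:
b, s, d = 4, 6, 8
x0, res = torch.randn(b, s, d), torch.randn(b, s, d)
dmask = (torch.rand(b, s, d) > 0.37).to(torch.uint8)
w, bias = torch.ones(d), torch.zeros(d)
keep_in = torch.tensor([True, True, False, True])
keep_out = torch.tensor([True, False, True, True])
out_sub, residual_full = subset_prenorm_reference(
    x0, res, dmask, w, bias, keep_in, keep_out, dropout_p=0.37, drop_path_scale=1 / (1 - 0.4)
)
```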
assert out.dtype == input_dtype
assert residual.dtype == residual_dtype
assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
assert (residual - residual_ref).abs().max() <= 4 * (residual_pt - residual_ref).abs().max() + 1e-4
assert (residual - residual_ref).abs().max() <= 4 * (
residual_pt - residual_ref
).abs().max() + 1e-4
g = torch.randn_like(out) / batch_size
(out_pt * F.sigmoid(residual_pt[out_mask_batch]) + residual_pt.mean(0, keepdim=True)).backward(g)
(out_pt * F.sigmoid(residual_pt[out_mask_batch]) + residual_pt.mean(0, keepdim=True)).backward(
g
)
(out * F.sigmoid(residual[out_mask_batch]) + residual.mean(0, keepdim=True)).backward(g)
(out_ref * F.sigmoid(residual_ref[out_mask_batch].to(dtype=residual_dtype)) + residual_ref.mean(0, keepdim=True)).backward(g)
assert (x0.grad - x0_ref.grad[x0_mask_batch]).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad)[x0_mask_batch].abs().max() + 1e-4
(
out_ref * F.sigmoid(residual_ref[out_mask_batch].to(dtype=residual_dtype))
+ residual_ref.mean(0, keepdim=True)
).backward(g)
assert (x0.grad - x0_ref.grad[x0_mask_batch]).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad)[
x0_mask_batch
].abs().max() + 1e-4
if has_residual:
assert (res.grad - res_ref.grad).abs().max() <= 4 * (res_pt.grad - res_ref.grad).abs().max() + 1e-4
assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4
assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4
assert (res.grad - res_ref.grad).abs().max() <= 4 * (
res_pt.grad - res_ref.grad
).abs().max() + 1e-4
assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (
model_pt.weight.grad - model_ref.weight.grad
).abs().max() + 2e-4
assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (
model_pt.bias.grad - model_ref.bias.grad
).abs().max() + 2e-4
if has_colscale:
assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4
assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (
colscale_pt.grad - colscale_ref.grad
).abs().max() + 2e-4
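The prenorm tests backpropagate through a small synthetic head so that gradients flow through both returned tensors (`out` and the prenorm `residual`). A stripped-down version of that pattern, without the subset indexing and with random stand-in tensors:

```python
import torch

# Minimal version of the backward pattern used by the prenorm tests: a synthetic
# head that mixes both outputs so gradients reach `out` and `residual`.
out = torch.randn(4, 6, 8, requires_grad=True)
residual = torch.randn(4, 6, 8, requires_grad=True)
g = torch.randn_like(out) / 4  # scaled like the tests (divide by batch size)
(out * torch.sigmoid(residual) + residual.mean(0, keepdim=True)).backward(g)
assert out.grad is not None and residual.grad is not None
```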
@pytest.mark.parametrize('is_rms_norm', [False, True])
@pytest.mark.parametrize("is_rms_norm", [False, True])
# @pytest.mark.parametrize('is_rms_norm', [False])
@pytest.mark.parametrize('tied_norm', [False, True])
@pytest.mark.parametrize("tied_norm", [False, True])
# @pytest.mark.parametrize('tied_norm', [False])
@pytest.mark.parametrize('has_residual', [True, False])
@pytest.mark.parametrize("has_residual", [True, False])
# @pytest.mark.parametrize('has_residual', [False])
@pytest.mark.parametrize('has_x1', [True, False])
@pytest.mark.parametrize("has_x1", [True, False])
# @pytest.mark.parametrize('has_x1', [True])
@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
@pytest.mark.parametrize("dropout_p", [0.37, 0.0])
# @pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
# @pytest.mark.parametrize('weight_dtype', [torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
[(torch.float16, torch.float16), (torch.float16, torch.float32),
(torch.float32, torch.float32)]
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
@pytest.mark.parametrize(
"input_dtype,residual_dtype",
[(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
@pytest.mark.parametrize(
"hidden_size",
[192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],
)
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_parallel_residual_training(
hidden_size, input_dtype, residual_dtype, weight_dtype,
dropout_p, has_x1, has_residual, tied_norm, is_rms_norm
hidden_size,
input_dtype,
residual_dtype,
weight_dtype,
dropout_p,
has_x1,
has_residual,
tied_norm,
is_rms_norm,
):
if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
pytest.skip() # Not supported
if is_rms_norm and fused_rms_norm_affine is None:
pytest.skip() # We need Apex's FusedRMSNorm to test
our_layer_norm_func = (dropout_add_layer_norm_parallel_residual if not is_rms_norm
else dropout_add_rms_norm_parallel_residual)
device = 'cuda'
our_layer_norm_func = (
dropout_add_layer_norm_parallel_residual
if not is_rms_norm
else dropout_add_rms_norm_parallel_residual
)
device = "cuda"
# rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
rtol, atol = (1e-3, 1e-4)
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen = 512
x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
requires_grad=True)
x0_pt = torch.randn(
batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
)
x0 = x0_pt.detach().clone().requires_grad_()
x0_ref = x0_pt.detach().clone().float().requires_grad_()
if has_x1:
x1_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
requires_grad=True)
x1_pt = torch.randn(
batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
)
x1 = x1_pt.detach().clone().requires_grad_()
x1_ref = x1_pt.detach().clone().float().requires_grad_()
else:
@@ -626,16 +806,22 @@ def test_dropout_layer_norm_parallel_residual_training(
else:
res = None
weight0 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
bias0 = (torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
if not is_rms_norm else None)
bias0 = (
torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
if not is_rms_norm
else None
)
weight0_pt = weight0.detach().clone().requires_grad_()
weight0_ref = weight0.detach().clone().float().requires_grad_()
bias0_pt = bias0.detach().clone().requires_grad_() if bias0 is not None else None
bias0_ref = bias0.detach().clone().float().requires_grad_() if bias0 is not None else None
if not tied_norm:
weight1 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
bias1 = (torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
if not is_rms_norm else None)
bias1 = (
torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
if not is_rms_norm
else None
)
weight1_pt = weight1.detach().clone().requires_grad_()
weight1_ref = weight1.detach().clone().float().requires_grad_()
bias1_pt = bias1.detach().clone().requires_grad_() if bias1 is not None else None
@@ -646,48 +832,77 @@ def test_dropout_layer_norm_parallel_residual_training(
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
out0, out1, dmask0, dmask1 = our_layer_norm_func(
x0, x1, res, weight0, bias0, weight1, bias1, dropout_p,
epsilon, residual_in_fp32=residual_in_fp32, return_dropout_mask=True
x0,
x1,
res,
weight0,
bias0,
weight1,
bias1,
dropout_p,
epsilon,
residual_in_fp32=residual_in_fp32,
return_dropout_mask=True,
)
assert out0.dtype == input_dtype
if not tied_norm:
assert out1.dtype == input_dtype
print(f'Actual dropout fraction: {1 - dmask0.float().mean().item()}')
print(f"Actual dropout fraction: {1 - dmask0.float().mean().item()}")
if has_residual:
if has_x1:
residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)
+ (x1_pt.float() * dmask1.float()) / (1 - dropout_p)
+ res_pt.float()).to(dtype=residual_dtype)
residual_ref = ((x0_ref * dmask0.float()) / (1 - dropout_p)
+ (x1_ref * dmask1.float()) / (1 - dropout_p)) + res_ref
residual_pt = (
(x0_pt.float() * dmask0.float()) / (1 - dropout_p)
+ (x1_pt.float() * dmask1.float()) / (1 - dropout_p)
+ res_pt.float()
).to(dtype=residual_dtype)
residual_ref = (
(x0_ref * dmask0.float()) / (1 - dropout_p)
+ (x1_ref * dmask1.float()) / (1 - dropout_p)
) + res_ref
else:
residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)
+ res_pt.float()).to(dtype=residual_dtype)
residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p) + res_pt.float()).to(
dtype=residual_dtype
)
residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p) + res_ref
else:
if has_x1:
residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)
+ (x1_pt.float() * dmask1.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
residual_ref = ((x0_ref * dmask0.float()) / (1 - dropout_p)
+ (x1_ref * dmask1.float()) / (1 - dropout_p))
residual_pt = (
(x0_pt.float() * dmask0.float()) / (1 - dropout_p)
+ (x1_pt.float() * dmask1.float()) / (1 - dropout_p)
).to(dtype=residual_dtype)
residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p) + (
x1_ref * dmask1.float()
) / (1 - dropout_p)
else:
residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)).to(
dtype=residual_dtype
)
residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p)
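The branches above all compute the same parallel-residual combination: each input stream is dropped out with its own saved mask and inverted scaling, and the results (plus the optional residual) are summed in fp32 before both norms read the shared sum. A compact sketch, with a function name of our choosing:

```python
import torch


def parallel_residual_reference(x0, x1, res, dmask0, dmask1, dropout_p):
    # Each input gets its own saved dropout mask with inverted scaling;
    # x1 and res may be None, matching the has_x1 / has_residual switches.
    residual = (x0.float() * dmask0.float()) / (1 - dropout_p)
    if x1 is not None:
        residual = residual + (x1.float() * dmask1.float()) / (1 - dropout_p)
    if res is not None:
        residual = residual + res.float()
    return residual


# Both norms then read the same residual, e.g.
# out0 = F.layer_norm(residual, (hidden,), weight0, bias0, eps)
# out1 = F.layer_norm(residual, (hidden,), weight1, bias1, eps)  # unless tied_norm
x0 = torch.randn(2, 4, 8)
m0 = (torch.rand_like(x0) > 0.37).float()
r = parallel_residual_reference(x0, None, None, m0, None, 0.37)
```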
if not is_rms_norm:
out0_pt = F.layer_norm(residual_pt.to(dtype=weight_dtype), (hidden_size,), weight0_pt, bias0_pt,
eps=epsilon).to(dtype=input_dtype)
out0_pt = F.layer_norm(
residual_pt.to(dtype=weight_dtype), (hidden_size,), weight0_pt, bias0_pt, eps=epsilon
).to(dtype=input_dtype)
out0_ref = F.layer_norm(residual_ref, (hidden_size,), weight0_ref, bias0_ref, eps=epsilon)
if not tied_norm:
out1_pt = F.layer_norm(residual_pt.to(dtype=weight_dtype), (hidden_size,), weight1_pt,
bias1_pt, eps=epsilon).to(dtype=input_dtype)
out1_ref = F.layer_norm(residual_ref, (hidden_size,), weight1_ref, bias1_ref, eps=epsilon)
out1_pt = F.layer_norm(
residual_pt.to(dtype=weight_dtype),
(hidden_size,),
weight1_pt,
bias1_pt,
eps=epsilon,
).to(dtype=input_dtype)
out1_ref = F.layer_norm(
residual_ref, (hidden_size,), weight1_ref, bias1_ref, eps=epsilon
)
else:
out0_pt = fused_rms_norm_affine(residual_pt.to(dtype=weight_dtype), weight0_pt, (hidden_size,),
eps=epsilon).to(dtype=input_dtype)
out0_pt = fused_rms_norm_affine(
residual_pt.to(dtype=weight_dtype), weight0_pt, (hidden_size,), eps=epsilon
).to(dtype=input_dtype)
out0_ref = fused_rms_norm_affine(residual_ref, weight0_ref, (hidden_size,), eps=epsilon)
if not tied_norm:
out1_pt = fused_rms_norm_affine(residual_pt.to(dtype=weight_dtype), weight1_pt,
(hidden_size,), eps=epsilon).to(dtype=input_dtype)
out1_pt = fused_rms_norm_affine(
residual_pt.to(dtype=weight_dtype), weight1_pt, (hidden_size,), eps=epsilon
).to(dtype=input_dtype)
out1_ref = fused_rms_norm_affine(residual_ref, weight1_ref, (hidden_size,), eps=epsilon)
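The `is_rms_norm` branch relies on Apex's `fused_rms_norm_affine` as the reference. For readers without Apex, a plain-PyTorch stand-in under the standard RMSNorm definition (the exact epsilon placement is our assumption, not taken from Apex):

```python
import torch


def rms_norm_affine_reference(x, weight, eps=1e-6):
    # Normalize by the root-mean-square over the last dimension (no mean
    # subtraction, no bias), then apply the elementwise weight.
    rms = x.float().pow(2).mean(dim=-1, keepdim=True).add(eps).rsqrt()
    return (x.float() * rms * weight.float()).to(x.dtype)


x = torch.randn(2, 4, 16, dtype=torch.float16)
w = torch.ones(16)
y = rms_norm_affine_reference(x, w)
```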
assert (out0 - out0_ref).abs().max() <= 4 * (out0_pt - out0_ref).abs().max() + 1e-4
@@ -706,61 +921,89 @@ def test_dropout_layer_norm_parallel_residual_training(
(out0_ref * g0 + out1_ref * g1).sum().backward()
assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
if has_x1:
assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (
x1_pt.grad - x1_ref.grad
).abs().max() + 1e-4
if has_residual:
assert (res.grad - res_ref.grad).abs().max() <= 4 * (res_pt.grad - res_ref.grad).abs().max() + 1e-4
assert (weight0.grad - weight0_ref.grad).abs().max() <= 3 * (weight0_pt.grad - weight0_ref.grad).abs().max() + 3e-5
assert (res.grad - res_ref.grad).abs().max() <= 4 * (
res_pt.grad - res_ref.grad
).abs().max() + 1e-4
assert (weight0.grad - weight0_ref.grad).abs().max() <= 3 * (
weight0_pt.grad - weight0_ref.grad
).abs().max() + 3e-5
if not is_rms_norm:
assert (bias0.grad - bias0_ref.grad).abs().max() <= 2 * (bias0_pt.grad - bias0_ref.grad).abs().max() + 3e-5
assert (bias0.grad - bias0_ref.grad).abs().max() <= 2 * (
bias0_pt.grad - bias0_ref.grad
).abs().max() + 3e-5
if not tied_norm:
assert (weight1.grad - weight1_ref.grad).abs().max() <= 3 * (weight1_pt.grad - weight1_ref.grad).abs().max() + 3e-5
assert (weight1.grad - weight1_ref.grad).abs().max() <= 3 * (
weight1_pt.grad - weight1_ref.grad
).abs().max() + 3e-5
if not is_rms_norm:
assert (bias1.grad - bias1_ref.grad).abs().max() <= 2 * (bias1_pt.grad - bias1_ref.grad).abs().max() + 3e-5
assert (bias1.grad - bias1_ref.grad).abs().max() <= 2 * (
bias1_pt.grad - bias1_ref.grad
).abs().max() + 3e-5
@pytest.mark.parametrize('is_rms_norm', [False, True])
@pytest.mark.parametrize("is_rms_norm", [False, True])
# @pytest.mark.parametrize('is_rms_norm', [False])
@pytest.mark.parametrize('tied_norm', [False, True])
@pytest.mark.parametrize("tied_norm", [False, True])
# @pytest.mark.parametrize('tied_norm', [False])
@pytest.mark.parametrize('has_residual', [True, False])
@pytest.mark.parametrize("has_residual", [True, False])
# @pytest.mark.parametrize('has_residual', [False])
@pytest.mark.parametrize('has_x1', [True, False])
@pytest.mark.parametrize("has_x1", [True, False])
# @pytest.mark.parametrize('has_x1', [True])
@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
@pytest.mark.parametrize("dropout_p", [0.37, 0.0])
# @pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
# @pytest.mark.parametrize('weight_dtype', [torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
[(torch.float16, torch.float16), (torch.float16, torch.float32),
(torch.float32, torch.float32)]
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
@pytest.mark.parametrize(
"input_dtype,residual_dtype",
[(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
@pytest.mark.parametrize(
"hidden_size",
[192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],
)
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_parallel_residual_prenorm_training(
hidden_size, input_dtype, residual_dtype, weight_dtype,
dropout_p, has_x1, has_residual, tied_norm, is_rms_norm
hidden_size,
input_dtype,
residual_dtype,
weight_dtype,
dropout_p,
has_x1,
has_residual,
tied_norm,
is_rms_norm,
):
if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
pytest.skip() # Not supported
if is_rms_norm and fused_rms_norm_affine is None:
pytest.skip() # We need Apex's FusedRMSNorm to test
our_layer_norm_func = (dropout_add_layer_norm_parallel_residual if not is_rms_norm
else dropout_add_rms_norm_parallel_residual)
device = 'cuda'
our_layer_norm_func = (
dropout_add_layer_norm_parallel_residual
if not is_rms_norm
else dropout_add_rms_norm_parallel_residual
)
device = "cuda"
# rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
rtol, atol = (1e-3, 1e-4)
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen = 512
x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
requires_grad=True)
x0_pt = torch.randn(
batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
)
x0 = x0_pt.detach().clone().requires_grad_()
x0_ref = x0_pt.detach().clone().float().requires_grad_()
if has_x1:
x1_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
requires_grad=True)
x1_pt = torch.randn(
batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
)
x1 = x1_pt.detach().clone().requires_grad_()
x1_ref = x1_pt.detach().clone().float().requires_grad_()
else:
@@ -772,16 +1015,22 @@ def test_dropout_layer_norm_parallel_residual_prenorm_training(
else:
res = None
weight0 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
bias0 = (torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
if not is_rms_norm else None)
bias0 = (
torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
if not is_rms_norm
else None
)
weight0_pt = weight0.detach().clone().requires_grad_()
weight0_ref = weight0.detach().clone().float().requires_grad_()
bias0_pt = bias0.detach().clone().requires_grad_() if bias0 is not None else None
bias0_ref = bias0.detach().clone().float().requires_grad_() if bias0 is not None else None
if not tied_norm:
weight1 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
bias1 = (torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
if not is_rms_norm else None)
bias1 = (
torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
if not is_rms_norm
else None
)
weight1_pt = weight1.detach().clone().requires_grad_()
weight1_ref = weight1.detach().clone().float().requires_grad_()
bias1_pt = bias1.detach().clone().requires_grad_() if bias1 is not None else None
@@ -792,54 +1041,86 @@ def test_dropout_layer_norm_parallel_residual_prenorm_training(
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
out0, out1, residual, dmask0, dmask1 = our_layer_norm_func(
x0, x1, res, weight0, bias0, weight1, bias1, dropout_p,
epsilon, prenorm=True, residual_in_fp32=residual_in_fp32, return_dropout_mask=True
x0,
x1,
res,
weight0,
bias0,
weight1,
bias1,
dropout_p,
epsilon,
prenorm=True,
residual_in_fp32=residual_in_fp32,
return_dropout_mask=True,
)
assert out0.dtype == input_dtype
if not tied_norm:
assert out1.dtype == input_dtype
print(f'Actual dropout fraction: {1 - dmask0.float().mean().item()}')
print(f"Actual dropout fraction: {1 - dmask0.float().mean().item()}")
if has_residual:
if has_x1:
residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)
+ (x1_pt.float() * dmask1.float()) / (1 - dropout_p)
+ res_pt.float()).to(dtype=residual_dtype)
residual_ref = ((x0_ref * dmask0.float()) / (1 - dropout_p)
+ (x1_ref * dmask1.float()) / (1 - dropout_p)) + res_ref
residual_pt = (
(x0_pt.float() * dmask0.float()) / (1 - dropout_p)
+ (x1_pt.float() * dmask1.float()) / (1 - dropout_p)
+ res_pt.float()
).to(dtype=residual_dtype)
residual_ref = (
(x0_ref * dmask0.float()) / (1 - dropout_p)
+ (x1_ref * dmask1.float()) / (1 - dropout_p)
) + res_ref
else:
residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)
+ res_pt.float()).to(dtype=residual_dtype)
residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p) + res_pt.float()).to(
dtype=residual_dtype
)
residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p) + res_ref
else:
if has_x1:
residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)
+ (x1_pt.float() * dmask1.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
residual_ref = ((x0_ref * dmask0.float()) / (1 - dropout_p)
+ (x1_ref * dmask1.float()) / (1 - dropout_p))
residual_pt = (
(x0_pt.float() * dmask0.float()) / (1 - dropout_p)
+ (x1_pt.float() * dmask1.float()) / (1 - dropout_p)
).to(dtype=residual_dtype)
residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p) + (
x1_ref * dmask1.float()
) / (1 - dropout_p)
else:
residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)).to(
dtype=residual_dtype
)
residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p)
if not is_rms_norm:
out0_pt = F.layer_norm(residual_pt.to(dtype=weight_dtype), (hidden_size,), weight0_pt, bias0_pt,
eps=epsilon).to(dtype=input_dtype)
out0_pt = F.layer_norm(
residual_pt.to(dtype=weight_dtype), (hidden_size,), weight0_pt, bias0_pt, eps=epsilon
).to(dtype=input_dtype)
out0_ref = F.layer_norm(residual_ref, (hidden_size,), weight0_ref, bias0_ref, eps=epsilon)
if not tied_norm:
out1_pt = F.layer_norm(residual_pt.to(dtype=weight_dtype), (hidden_size,), weight1_pt,
bias1_pt, eps=epsilon).to(dtype=input_dtype)
out1_ref = F.layer_norm(residual_ref, (hidden_size,), weight1_ref, bias1_ref, eps=epsilon)
out1_pt = F.layer_norm(
residual_pt.to(dtype=weight_dtype),
(hidden_size,),
weight1_pt,
bias1_pt,
eps=epsilon,
).to(dtype=input_dtype)
out1_ref = F.layer_norm(
residual_ref, (hidden_size,), weight1_ref, bias1_ref, eps=epsilon
)
else:
out0_pt = fused_rms_norm_affine(residual_pt.to(dtype=weight_dtype), weight0_pt, (hidden_size,),
eps=epsilon).to(dtype=input_dtype)
out0_pt = fused_rms_norm_affine(
residual_pt.to(dtype=weight_dtype), weight0_pt, (hidden_size,), eps=epsilon
).to(dtype=input_dtype)
out0_ref = fused_rms_norm_affine(residual_ref, weight0_ref, (hidden_size,), eps=epsilon)
if not tied_norm:
out1_pt = fused_rms_norm_affine(residual_pt.to(dtype=weight_dtype), weight1_pt,
(hidden_size,), eps=epsilon).to(dtype=input_dtype)
out1_pt = fused_rms_norm_affine(
residual_pt.to(dtype=weight_dtype), weight1_pt, (hidden_size,), eps=epsilon
).to(dtype=input_dtype)
out1_ref = fused_rms_norm_affine(residual_ref, weight1_ref, (hidden_size,), eps=epsilon)
assert (out0 - out0_ref).abs().max() <= 4 * (out0_pt - out0_ref).abs().max() + 1e-4
if not tied_norm:
assert (out1 - out1_ref).abs().max() <= 4 * (out1_pt - out1_ref).abs().max() + 1e-4
assert (residual - residual_ref).abs().max() <= 4 * (residual_pt - residual_ref).abs().max() + 1e-4
assert (residual - residual_ref).abs().max() <= 4 * (
residual_pt - residual_ref
).abs().max() + 1e-4
g0 = torch.randn_like(out0) / batch_size
if tied_norm:
@@ -853,39 +1134,56 @@ def test_dropout_layer_norm_parallel_residual_prenorm_training(
(out0_ref * F.sigmoid(residual_ref) * g0 + out1_ref * g1).sum().backward()
assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
if has_x1:
assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (
x1_pt.grad - x1_ref.grad
).abs().max() + 1e-4
if has_residual:
assert (res.grad - res_ref.grad).abs().max() <= 4 * (res_pt.grad - res_ref.grad).abs().max() + 1e-4
assert (weight0.grad - weight0_ref.grad).abs().max() <= 3 * (weight0_pt.grad - weight0_ref.grad).abs().max() + 3e-5
assert (res.grad - res_ref.grad).abs().max() <= 4 * (
res_pt.grad - res_ref.grad
).abs().max() + 1e-4
assert (weight0.grad - weight0_ref.grad).abs().max() <= 3 * (
weight0_pt.grad - weight0_ref.grad
).abs().max() + 3e-5
if not is_rms_norm:
assert (bias0.grad - bias0_ref.grad).abs().max() <= 2 * (bias0_pt.grad - bias0_ref.grad).abs().max() + 3e-5
assert (bias0.grad - bias0_ref.grad).abs().max() <= 2 * (
bias0_pt.grad - bias0_ref.grad
).abs().max() + 3e-5
if not tied_norm:
assert (weight1.grad - weight1_ref.grad).abs().max() <= 3 * (weight1_pt.grad - weight1_ref.grad).abs().max() + 3e-5
assert (weight1.grad - weight1_ref.grad).abs().max() <= 3 * (
weight1_pt.grad - weight1_ref.grad
).abs().max() + 3e-5
if not is_rms_norm:
assert (bias1.grad - bias1_ref.grad).abs().max() <= 2 * (bias1_pt.grad - bias1_ref.grad).abs().max() + 3e-5
assert (bias1.grad - bias1_ref.grad).abs().max() <= 2 * (
bias1_pt.grad - bias1_ref.grad
).abs().max() + 3e-5
def test_dropout_layer_norm_randomness():
hidden_size = 256
dtype = torch.float32
dropout_p = 0.1
device = 'cuda'
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen = 512
x0 = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=dtype, requires_grad=True)
x0 = torch.randn(
batch_size, seqlen, hidden_size, device=device, dtype=dtype, requires_grad=True
)
res = torch.randn_like(x0, dtype=dtype, requires_grad=True)
model = DropoutAddLayerNorm(hidden_size, p=dropout_p, device=device, dtype=dtype)
torch.random.manual_seed(42)
_, dmask0 = dropout_add_layer_norm(x0, res, model.weight, model.bias, model.p,
model.eps, return_dropout_mask=True)
_, dmask0 = dropout_add_layer_norm(
x0, res, model.weight, model.bias, model.p, model.eps, return_dropout_mask=True
)
# Subsequent call should have a different dropout mask
_, dmask1 = dropout_add_layer_norm(x0, res, model.weight, model.bias, model.p,
model.eps, return_dropout_mask=True)
_, dmask1 = dropout_add_layer_norm(
x0, res, model.weight, model.bias, model.p, model.eps, return_dropout_mask=True
)
torch.random.manual_seed(42)
# Resetting the seed, should get the same dropout mask
_, dmask2 = dropout_add_layer_norm(x0, res, model.weight, model.bias, model.p,
model.eps, return_dropout_mask=True)
_, dmask2 = dropout_add_layer_norm(
x0, res, model.weight, model.bias, model.p, model.eps, return_dropout_mask=True
)
assert not torch.equal(dmask0, dmask1)
assert torch.equal(dmask0, dmask2)
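The randomness test checks two RNG properties: consecutive calls draw fresh dropout masks, and resetting the seed reproduces the first mask. The fused kernel draws from the CUDA generator; the CPU-only sketch below shows the same seed-reset behaviour with `F.dropout`:

```python
import torch
import torch.nn.functional as F

x = torch.ones(256, 256)
torch.random.manual_seed(42)
mask_a = F.dropout(x, p=0.1) != 0  # first draw from the generator
mask_b = F.dropout(x, p=0.1) != 0  # second draw: a different mask
torch.random.manual_seed(42)
mask_c = F.dropout(x, p=0.1) != 0  # reseeded: reproduces the first mask
assert not torch.equal(mask_a, mask_b)
assert torch.equal(mask_a, mask_c)
```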