Unverified Commit ec0d40d6 authored by Kirthi Shankar Sivamani, committed by GitHub

PyTorch API numeric tests (#215)



* LayerNormMLP numeric test
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* DotProductAttention numeric test
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent ce3980c8
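For context, the tests introduced here follow the same pattern as the existing accuracy tests in this file: build a Transformer Engine module and a plain PyTorch reference module, copy the TE parameters into the reference, run both on identical inputs, and compare the outputs within a dtype-dependent tolerance. A minimal sketch of that pattern, assuming hypothetical te_module / ref_module / make_input objects and using torch.testing.assert_close in place of the file's assert_allclose helper:

import torch

def compare_forward(te_module, ref_module, make_input, tol):
    # Eval mode disables dropout so both forward passes are deterministic.
    te_module.eval()
    ref_module.eval()
    x = make_input()  # identical input for both paths
    with torch.no_grad():
        te_out = te_module(x)
        ref_out = ref_module(x)
    # Loose, dtype-dependent tolerance, as in the tests below (5e-3 for fp32, 5e-2 otherwise).
    torch.testing.assert_close(te_out, ref_out, atol=tol, rtol=0.0)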
@@ -2,6 +2,7 @@
 #
 # See LICENSE for license information.
 import math
+import os
 import contextlib
 from typing import List, Optional
@@ -17,8 +18,11 @@ from torch.cuda import _lazy_call, device as device_ctx_manager
 from transformer_engine.pytorch.utils import (
     init_method_normal,
     scaled_init_method_normal,
+    attention_mask_func,
 )
+from transformer_engine.pytorch import (
+    DotProductAttention, Linear, LayerNormLinear, LayerNormMLP, TransformerLayer
+)
-from transformer_engine.pytorch import Linear, LayerNormLinear, TransformerLayer
 from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
@@ -192,6 +196,120 @@ def get_dummy_cuda_rng_tracker():
     return _DUMMY_CUDA_RNG_STATE_TRACKER
 
+
+class TorchScaledMaskedSoftmax(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(
+        self, inp: torch.Tensor, mask: torch.Tensor, scale: Optional[float] = None
+    ) -> torch.Tensor:
+        dtype = inp.dtype
+        inp = inp.float()
+
+        if scale is not None:
+            inp = inp * scale
+        mask_output = attention_mask_func(inp, mask) if mask is not None else inp
+
+        probs = torch.nn.Softmax(dim=-1)(mask_output)
+        probs = probs.to(dtype)
+        return probs
+
+
+class TorchDotProductAttention(torch.nn.Module):
+    def __init__(
+        self,
+        kv_channels: int,
+        attention_dropout: float = 0.0,
+    ) -> None:
+        super().__init__()
+
+        self.norm_factor = math.sqrt(kv_channels)
+        self.scale_mask_softmax = TorchScaledMaskedSoftmax()
+        self.attention_dropout = torch.nn.Dropout(attention_dropout)
+
+    def forward(
+        self,
+        query_layer: torch.Tensor,
+        key_layer: torch.Tensor,
+        value_layer: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]
+
+        # [b, np, sq, sk]
+        output_size = (
+            query_layer.size(1),
+            query_layer.size(2),
+            query_layer.size(0),
+            key_layer.size(0),
+        )
+
+        # [sq, b, np, hn] -> [sq, b * np, hn]
+        query_layer = query_layer.reshape(
+            output_size[2], output_size[0] * output_size[1], -1
+        )
+        # [sk, b, np, hn] -> [sk, b * np, hn]
+        key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1)
+
+        # preallocating result tensor: [b * np, sq, sk]
+        matmul_result = torch.empty(
+            output_size[0] * output_size[1],
+            output_size[2],
+            output_size[3],
+            dtype=query_layer.dtype,
+            device=torch.cuda.current_device(),
+        )
+
+        # Raw attention scores. [b * np, sq, sk]
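+        # With beta=0.0, torch.baddbmm ignores the contents of the preallocated buffer
+        # and returns alpha * (Q @ K^T), i.e. the dot-product scores scaled by
+        # 1 / sqrt(kv_channels).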
+        matmul_result = torch.baddbmm(
+            matmul_result,
+            query_layer.transpose(0, 1),  # [b * np, sq, hn]
+            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
+            beta=0.0,
+            alpha=(1.0 / self.norm_factor),
+        )
+
+        # change view to [b, np, sq, sk]
+        attention_scores = matmul_result.view(*output_size)
+
+        # attention scores and attention mask [b, np, sq, sk]
+        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)
+
+        attention_probs = self.attention_dropout(attention_probs)
+
+        # value_layer -> context layer.
+        # [sk, b, np, hn] --> [b, np, sq, hn]
+        output_size = (
+            value_layer.size(1),
+            value_layer.size(2),
+            query_layer.size(0),
+            value_layer.size(3),
+        )
+
+        # change view [sk, b * np, hn]
+        value_layer = value_layer.reshape(
+            value_layer.size(0), output_size[0] * output_size[1], -1
+        )
+
+        # change view [b * np, sq, sk]
+        attention_probs = attention_probs.view(
+            output_size[0] * output_size[1], output_size[2], -1
+        )
+
+        # matmul: [b * np, sq, hn]
+        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
+
+        # change view [b, np, sq, hn]
+        context_layer = context_layer.view(*output_size)
+
+        # [b, np, sq, hn] --> [sq, b, np, hn]
+        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
+
+        # [sq, b, np, hn] --> [sq, b, hp]
+        context_layer = context_layer.view(seqlen, batch_size, -1)
+
+        return context_layer
+
+
 class TorchLayerNormLinear(nn.Module):
     def __init__(self, in_features: int, out_features: int, eps: float, bias: bool = True):
         super().__init__()
@@ -217,24 +335,24 @@ class TorchMHA(nn.Module):
         return self.mhsa(x, x, x, attn_mask=attn_mask, need_weights=False)
 
 
-class TorchMLP(nn.Module):
-    def __init__(self, hidden_size: int):
+class TorchLayerNormMLP(nn.Module):
+    def __init__(self, hidden_size: int, ffn_hidden_size: int, eps: float = 1e-5):
         super().__init__()
-        self.fc1 = nn.Linear(hidden_size, 4 * hidden_size)
+        self.ln = nn.LayerNorm(hidden_size, eps=eps)
+        self.fc1 = nn.Linear(hidden_size, ffn_hidden_size)
         self.gelu = nn.GELU(approximate="tanh")
-        self.fc2 = nn.Linear(4 * hidden_size, hidden_size)
+        self.fc2 = nn.Linear(ffn_hidden_size, hidden_size)
 
     def forward(self, x):
-        return self.fc2(self.gelu(self.fc1(x)))
+        return self.fc2(self.gelu(self.fc1(self.ln(x))))
 
 
 class TorchGPT(nn.Module):
     def __init__(self, hidden_size: int, eps: float, num_attention_heads: int):
         super().__init__()
-        self.ln_1 = nn.LayerNorm(hidden_size, eps=eps)
+        self.ln = nn.LayerNorm(hidden_size, eps=eps)
         self.causal_attn = TorchMHA(hidden_size, num_attention_heads)
-        self.ln_2 = nn.LayerNorm(hidden_size, eps=eps)
-        self.mlp = TorchMLP(hidden_size)
+        self.ln_mlp = TorchLayerNormMLP(hidden_size, 4 * hidden_size, eps)
         self.resid_attn_dropout = nn.Dropout(0.1)
         self.resid_mlp_dropout = nn.Dropout(0.1)
@@ -243,11 +361,10 @@ class TorchGPT(nn.Module):
         x: torch.Tensor,
         attn_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        a = self.ln_1(x)
+        a = self.ln(x)
         b, _ = self.causal_attn(a, attn_mask)
         x = x + self.resid_attn_dropout(b)
-        m = self.ln_2(x)
-        n = self.mlp(m)
+        n = self.ln_mlp(x)
         x = x + self.resid_mlp_dropout(n)
         return x
@@ -535,10 +652,10 @@ def test_gpt_accuracy(dtype, bs, model):
 
     # Share params
    with torch.no_grad():
-        torch_gpt.ln_1.weight = Parameter(
+        torch_gpt.ln.weight = Parameter(
             te_gpt.self_attention.layernorm_qkv.layer_norm_weight.clone()
         )
-        torch_gpt.ln_1.bias = Parameter(te_gpt.self_attention.layernorm_qkv.layer_norm_bias.clone())
+        torch_gpt.ln.bias = Parameter(te_gpt.self_attention.layernorm_qkv.layer_norm_bias.clone())
         torch_gpt.causal_attn.mhsa.in_proj_weight = Parameter(
             te_gpt.self_attention.layernorm_qkv.weight.clone()
         )
@@ -551,12 +668,12 @@
         torch_gpt.causal_attn.mhsa.out_proj.bias = Parameter(
             te_gpt.self_attention.proj.bias.clone()
         )
-        torch_gpt.ln_2.weight = Parameter(te_gpt.layernorm_mlp.layer_norm_weight.clone())
-        torch_gpt.ln_2.bias = Parameter(te_gpt.layernorm_mlp.layer_norm_bias.clone())
-        torch_gpt.mlp.fc1.weight = Parameter(te_gpt.layernorm_mlp.fc1_weight.clone())
-        torch_gpt.mlp.fc1.bias = Parameter(te_gpt.layernorm_mlp.fc1_bias.clone())
-        torch_gpt.mlp.fc2.weight = Parameter(te_gpt.layernorm_mlp.fc2_weight.clone())
-        torch_gpt.mlp.fc2.bias = Parameter(te_gpt.layernorm_mlp.fc2_bias.clone())
+        torch_gpt.ln_mlp.ln.weight = Parameter(te_gpt.layernorm_mlp.layer_norm_weight.clone())
+        torch_gpt.ln_mlp.ln.bias = Parameter(te_gpt.layernorm_mlp.layer_norm_bias.clone())
+        torch_gpt.ln_mlp.fc1.weight = Parameter(te_gpt.layernorm_mlp.fc1_weight.clone())
+        torch_gpt.ln_mlp.fc1.bias = Parameter(te_gpt.layernorm_mlp.fc1_bias.clone())
+        torch_gpt.ln_mlp.fc2.weight = Parameter(te_gpt.layernorm_mlp.fc2_weight.clone())
+        torch_gpt.ln_mlp.fc2.bias = Parameter(te_gpt.layernorm_mlp.fc2_bias.clone())
 
     te_outputs = _test_e2e_gpt_accuracy(te_gpt, bs, dtype, config)
     torch_outputs = _test_e2e_gpt_accuracy(torch_gpt, bs, dtype, config)
@@ -588,6 +705,64 @@ def _test_granular_accuracy(block, bs, dtype, config):
     return outputs
 
 
+def _test_dpa_accuracy(block, bs, dtype, config):
+    reset_rng_states()
+
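+    # Boolean causal mask: True entries above the diagonal mark future positions,
+    # which are filled with a large negative value before the softmax.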
+    mask = torch.triu(torch.ones(config.seq_len, config.seq_len, device="cuda"), diagonal=1).bool()
+    query, key, value = [
+        torch.randn(config.seq_len, bs, config.num_attention_heads,
+                    config.embed, dtype=dtype, requires_grad=True).cuda() for _ in range(3)]
+    query.retain_grad()
+    key.retain_grad()
+    value.retain_grad()
+
+    out = block(query, key, value, mask)
+    loss = out.sum()
+    loss.backward()
+
+    torch.cuda.synchronize()
+
+    return [out, query.grad, key.grad, value.grad]
+
+
+@pytest.mark.parametrize("dtype", param_types)
+@pytest.mark.parametrize("bs", batch_sizes)
+@pytest.mark.parametrize("model", model_configs.keys())
+def test_dpa_accuracy(dtype, bs, model):
+    config = model_configs[model]
+
+    te_dpa = (
+        DotProductAttention(
+            config.num_attention_heads,
+            config.embed,
+            0.1,  # dropout
+        )
+        .to(dtype=dtype)
+        .cuda()
+        .eval()
+    )
+
+    torch_dpa = (
+        TorchDotProductAttention(
+            config.embed,
+            0.1,  # dropout
+        )
+        .to(dtype=dtype)
+        .cuda()
+        .eval()
+    )
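+
+    # Both modules run in eval mode, so the 0.1 dropout is disabled and the forward
+    # outputs are deterministic and directly comparable.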
+    te_outputs = _test_dpa_accuracy(te_dpa, bs, dtype, config)
+    torch_outputs = _test_dpa_accuracy(torch_dpa, bs, dtype, config)
+
+    # Check output.
+    if dtype == torch.float32:
+        assert_allclose(te_outputs[0], torch_outputs[0], 5e-3)
+    else:
+        assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)
+
 
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
@@ -678,6 +853,51 @@ def test_layernorm_linear_accuracy(dtype, bs, model):
         assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)
 
 
+@pytest.mark.parametrize("dtype", param_types)
+@pytest.mark.parametrize("bs", batch_sizes)
+@pytest.mark.parametrize("model", model_configs.keys())
+def test_layernorm_mlp_accuracy(dtype, bs, model):
+    config = model_configs[model]
+
+    te_ln_mlp = (
+        LayerNormMLP(
+            config.hidden_size,
+            4 * config.hidden_size,
+        )
+        .to(dtype=dtype)
+        .cuda()
+        .eval()
+    )
+
+    torch_ln_mlp = (
+        TorchLayerNormMLP(
+            config.hidden_size,
+            4 * config.hidden_size,
+        )
+        .to(dtype=dtype)
+        .cuda()
+        .eval()
+    )
+
+    # Share params
+    with torch.no_grad():
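+        # Assumption: TE's layer_norm_*, fc1_* and fc2_* parameters use the same layouts
+        # as nn.LayerNorm / nn.Linear (weights stored as [out_features, in_features]),
+        # so cloning them directly gives the reference module identical parameters.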
+        torch_ln_mlp.ln.weight = Parameter(te_ln_mlp.layer_norm_weight.clone())
+        torch_ln_mlp.ln.bias = Parameter(te_ln_mlp.layer_norm_bias.clone())
+        torch_ln_mlp.fc1.weight = Parameter(te_ln_mlp.fc1_weight.clone())
+        torch_ln_mlp.fc1.bias = Parameter(te_ln_mlp.fc1_bias.clone())
+        torch_ln_mlp.fc2.weight = Parameter(te_ln_mlp.fc2_weight.clone())
+        torch_ln_mlp.fc2.bias = Parameter(te_ln_mlp.fc2_bias.clone())
+
+    te_outputs = _test_granular_accuracy(te_ln_mlp, bs, dtype, config)
+    torch_outputs = _test_granular_accuracy(torch_ln_mlp, bs, dtype, config)
+
+    # Check output.
+    if dtype == torch.float32:
+        assert_allclose(te_outputs[0], torch_outputs[0], 5e-3)
+    else:
+        assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)
+
 
 def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph):
     reset_rng_states()