Unverified commit f126a04f authored by Kirthi Shankar Sivamani, committed by GitHub

[PyTorch] Add tests for cuda graph capture (#144)

* Add tests for cuda graph capture

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add sanity test and address reviews

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 7396c527
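Both new test helpers follow the capture/replay workflow from PyTorch's CUDA graphs documentation: warm up for a few iterations on a side stream, capture one full training step into a `torch.cuda.CUDAGraph` using static input/target tensors, then feed new data by copying into those static tensors and replaying the graph. A minimal, self-contained sketch of that pattern is shown below for reference; the toy `torch.nn.Linear` model, shapes, and iteration count are illustrative and not part of this change.

```python
import torch

# Illustrative stand-in model; the tests in this PR exercise TransformerLayer instead.
model = torch.nn.Linear(64, 64, device="cuda")
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Static tensors: graph replay reuses their fixed memory addresses.
static_input = torch.randn(32, 64, device="cuda", requires_grad=True)
static_target = torch.randn(32, 64, device="cuda")

# Warm up on a side stream so lazily initialized state is allocated outside capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        optimizer.zero_grad(set_to_none=True)
        loss_fn(model(static_input), static_target).backward()
        optimizer.step()
torch.cuda.current_stream().wait_stream(s)

# Capture one training step.
g = torch.cuda.CUDAGraph()
optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
    static_output = model(static_input)
    static_loss = loss_fn(static_output, static_target)
    static_loss.backward()
    optimizer.step()

# Run on new data: copy it into the static tensors, then replay the captured work.
with torch.no_grad():
    static_input.copy_(torch.rand_like(static_input))
    static_target.copy_(torch.rand_like(static_target))
g.replay()
torch.cuda.synchronize()
```

The side-stream warmup and the copy-into-static-tensors step are both requirements of graph capture, which is why the same structure appears in `_test_gpt_e2e_cuda_graph` and `_test_sanity_e2e_cuda_graph` below.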
@@ -6,6 +6,7 @@ import os
import contextlib
from typing import List, Optional
import pytest
import copy
import torch
import torch.nn as nn
@@ -675,3 +676,102 @@ def test_layernorm_linear_accuracy(dtype, bs, model):
        assert_allclose(te_outputs[0], torch_outputs[0], 5e-3)
    else:
        assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)

def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph):
    reset_rng_states()

    # Initialize loss function and optimizer.
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(block.parameters(), lr=0.1)

    # Placeholders used for capture.
    static_input = torch.randn(config.seq_len, bs, config.hidden_size, device='cuda', dtype=dtype, requires_grad=True)
    static_target = torch.randn(config.seq_len, bs, config.hidden_size, device='cuda', dtype=dtype)
    real_input = torch.rand_like(static_input)
    real_target = torch.rand_like(static_target)

    if graph:
        # Pre graph capture warmup in a separate stream.
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            for _ in range(3):
                optimizer.zero_grad(set_to_none=True)
                out = block(static_input)
                loss = loss_fn(out, static_target)
                loss.backward()
                optimizer.step()
        torch.cuda.current_stream().wait_stream(s)

        # Capture.
        g = torch.cuda.CUDAGraph()
        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.graph(g):
            static_output = block(static_input)
            static_loss = loss_fn(static_output, static_target)
            static_loss.backward()
            optimizer.step()

        # Fills the graph's input memory with new data to compute on.
        with torch.no_grad():
            static_input.copy_(real_input)
            static_target.copy_(real_target)
        g.replay()
    else:
        with torch.no_grad():
            static_input.copy_(real_input)
            static_target.copy_(real_target)
        optimizer.zero_grad(set_to_none=True)
        static_output = block(static_input)
        loss = loss_fn(static_output, static_target)
        loss.backward()
        optimizer.step()

    torch.cuda.synchronize()

    grads = [static_input.grad]
    for p in block.parameters():
        if p.requires_grad:
            grads.append(p.grad)

    with torch.no_grad():
        output = static_output.clone()

    return output, grads
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_gpt_cuda_graph(dtype, bs, model):
config = model_configs[model]
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
)
.to(dtype=dtype)
.cuda()
)
graphed_block = copy.deepcopy(block)
out, _ = _test_gpt_e2e_cuda_graph(block, bs, dtype, config, False)
graph_out, _ = _test_gpt_e2e_cuda_graph(graphed_block, bs, dtype, config, True)
# Check output.
assert_allclose(out, graph_out, 9e-1)
@@ -100,6 +100,54 @@ def _disable_wgrads(block):
        p.requires_grad = False

def _test_sanity_e2e_cuda_graph(block, bs, dtype, config, fp8_recipe, skip_wgrad):
    # Initialize loss function and optimizer.
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(block.parameters(), lr=0.1)

    # Placeholders used for capture.
    static_input = torch.randn(config.seq_len, bs, config.hidden_size, device='cuda', dtype=dtype, requires_grad=True)
    static_target = torch.randn(config.seq_len, bs, config.hidden_size, device='cuda', dtype=dtype)
    real_input = torch.rand_like(static_input)
    real_target = torch.rand_like(static_target)

    use_fp8 = fp8_recipe is not None
    if skip_wgrad:
        _disable_wgrads(block)

    # Pre graph capture warmup in a separate stream.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            optimizer.zero_grad(set_to_none=True)
            with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
                out = block(static_input)
            loss = loss_fn(out, static_target)
            loss.backward()
            optimizer.step()
    torch.cuda.current_stream().wait_stream(s)

    # Capture.
    g = torch.cuda.CUDAGraph()
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.graph(g):
        with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
            static_output = block(static_input)
        static_loss = loss_fn(static_output, static_target)
        static_loss.backward()
        optimizer.step()

    # Fills the graph's input memory with new data to compute on.
    with torch.no_grad():
        static_input.copy_(real_input)
        static_target.copy_(real_target)
    g.replay()

    torch.cuda.synchronize()

def _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe, skip_wgrad):
    te_inp_hidden_states = torch.randn(
        config.seq_len, bs, config.hidden_size, dtype=torch.float32, requires_grad=True
@@ -608,3 +656,42 @@ def test_sanity_gradient_accumulation_fusion(dtype, bs, fp8_recipe, model, skip_
    )
    _test_sanity_e2e_gradient_accumulation_fusion(block, bs, dtype, config, fp8_recipe, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_gpt_cuda_graph(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
config = model_configs[model]
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
zero_centered_gamma=zero_centered_gamma,
fuse_qkv_params=True,
)
.to(dtype=dtype)
.cuda()
)
_test_sanity_e2e_cuda_graph(block, bs, dtype, config, fp8_recipe, skip_wgrad)
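As a usage note, both new tests can be selected by keyword from their respective test directories (the repository's existing test file names and layout are assumed and omitted here); a programmatic equivalent of `pytest -v -k cuda_graph` is sketched below.

```python
# Hypothetical invocation: run only the newly added CUDA graph tests by keyword.
import pytest

raise SystemExit(pytest.main(["-v", "-k", "cuda_graph"]))
```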