Unverified commit b2b3fbe7 authored by Tim Moon, committed by GitHub

Tighten tolerances for graph capture test (#153)


Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent c0451dd1
```diff
@@ -685,51 +685,46 @@ def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph):
     loss_fn = torch.nn.MSELoss()
     optimizer = torch.optim.SGD(block.parameters(), lr=0.1)
 
-    # Placeholders used for capture.
+    # Placeholders used for graph capture.
     static_input = torch.randn(config.seq_len, bs, config.hidden_size, device='cuda', dtype=dtype, requires_grad=True)
     static_target = torch.randn(config.seq_len, bs, config.hidden_size, device='cuda', dtype=dtype)
     real_input = torch.rand_like(static_input)
     real_target = torch.rand_like(static_target)
 
+    # Basic training loop.
+    def train_step():
+        optimizer.zero_grad(set_to_none=False)
+        out = block(static_input)
+        loss = loss_fn(out, static_target)
+        loss.backward()
+        optimizer.step()
+        return out
+
+    # Warmup steps in a separate stream.
+    s = torch.cuda.Stream()
+    s.wait_stream(torch.cuda.current_stream())
+    with torch.cuda.stream(s):
+        for _ in range(3):
+            train_step()
+    torch.cuda.current_stream().wait_stream(s)
+
+    # Capture graph.
+    g = None
+    static_output = None
     if graph:
-        # Pre graph capture warmup in a separate stream.
-        s = torch.cuda.Stream()
-        s.wait_stream(torch.cuda.current_stream())
-        with torch.cuda.stream(s):
-            for _ in range(3):
-                optimizer.zero_grad(set_to_none=True)
-                out = block(static_input)
-                loss = loss_fn(out, static_target)
-                loss.backward()
-                optimizer.step()
-        torch.cuda.current_stream().wait_stream(s)
-
-        # Capture.
         g = torch.cuda.CUDAGraph()
-        optimizer.zero_grad(set_to_none=True)
         with torch.cuda.graph(g):
-            static_output = block(static_input)
-            static_loss = loss_fn(static_output, static_target)
-            static_loss.backward()
-            optimizer.step()
-
-        # Fills the graph's input memory with new data to compute on
-        with torch.no_grad():
-            static_input.copy_(real_input)
-            static_target.copy_(real_target)
+            static_output = train_step()
+
+    # Run with new data.
+    with torch.no_grad():
+        static_input.copy_(real_input)
+        static_target.copy_(real_target)
+    if graph:
         g.replay()
     else:
-        with torch.no_grad():
-            static_input.copy_(real_input)
-            static_target.copy_(real_target)
-        optimizer.zero_grad(set_to_none=True)
-        static_output = block(static_input)
-        loss = loss_fn(static_output, static_target)
-        loss.backward()
-        optimizer.step()
+        static_output = train_step()
 
     torch.cuda.synchronize()
 
     grads = [static_input.grad]
     for p in block.parameters():
```
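For reference, the rewritten test follows the standard PyTorch CUDA Graphs training pattern: run a few warmup iterations on a side stream, capture one full training step into a `torch.cuda.CUDAGraph`, then feed new data by copying it into the capture-time tensors and replaying the graph. A minimal self-contained sketch of that pattern, using a stand-in `torch.nn.Linear` model rather than the GPT block from the test (the model, shapes, and iteration count here are illustrative assumptions):

```python
import torch

# Hypothetical stand-in model; the real test uses a Transformer Engine GPT block.
model = torch.nn.Linear(64, 64, device="cuda")
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Static tensors whose memory is baked into the captured graph.
static_input = torch.randn(32, 64, device="cuda", requires_grad=True)
static_target = torch.randn(32, 64, device="cuda")

def train_step():
    # set_to_none=False keeps .grad buffers allocated so they can be captured.
    optimizer.zero_grad(set_to_none=False)
    out = model(static_input)
    loss_fn(out, static_target).backward()
    optimizer.step()
    return out

# Warm up on a side stream so lazy initialization is not captured.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        train_step()
torch.cuda.current_stream().wait_stream(s)

# Capture one full training step into a CUDA graph.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_output = train_step()

# Run on new data by overwriting the static tensors in place, then replaying.
with torch.no_grad():
    static_input.copy_(torch.rand_like(static_input))
    static_target.copy_(torch.rand_like(static_target))
g.replay()
torch.cuda.synchronize()
```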
```diff
@@ -770,8 +765,12 @@ def test_gpt_cuda_graph(dtype, bs, model):
     )
     graphed_block = copy.deepcopy(block)
 
-    out, _ = _test_gpt_e2e_cuda_graph(block, bs, dtype, config, False)
-    graph_out, _ = _test_gpt_e2e_cuda_graph(graphed_block, bs, dtype, config, True)
+    out, grads = _test_gpt_e2e_cuda_graph(block, bs, dtype, config, False)
+    graphed_out, graphed_grads = _test_gpt_e2e_cuda_graph(graphed_block, bs, dtype, config, True)
+    params = list(block.parameters())
+    graphed_params = list(graphed_block.parameters())
 
-    # Check output.
-    assert_allclose(out, graph_out, 9e-1)
+    # Check that results match
+    assert_allclose(out, graphed_out, 1e-3)
+    assert_allclose(params, graphed_params, 1e-3)
+    assert_allclose(grads, graphed_grads, 1e-3)
```
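The tightened assertions compare the eager and graphed runs on three fronts: forward outputs, updated parameters, and input gradients, all at a 1e-3 tolerance via the test suite's `assert_allclose` helper. A rough standalone equivalent using `torch.testing.assert_close` (the `check_match` helper and the mapping of the single tolerance onto `rtol`/`atol` are assumptions, not the repo helper's exact semantics):

```python
import torch

def check_match(out, graphed_out, params, graphed_params, grads, graphed_grads, tol=1e-3):
    # Compare forward outputs, updated parameters, and gradients between the
    # eager and CUDA-graph runs. Mapping tol onto both rtol and atol is an assumption.
    torch.testing.assert_close(out, graphed_out, rtol=tol, atol=tol)
    for p, gp in zip(params, graphed_params):
        torch.testing.assert_close(p, gp, rtol=tol, atol=tol)
    for g, gg in zip(grads, graphed_grads):
        torch.testing.assert_close(g, gg, rtol=tol, atol=tol)
```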