Unverified Commit 021e1e62 authored by Paweł Gadziński's avatar Paweł Gadziński Committed by GitHub
Browse files

[PyTorch Debug] Fix issue with microbatching + debug value caching (#2108)



* fix perf issue
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
parent e2f2a0b4
...@@ -28,13 +28,15 @@ def _run_cpu_overhead(debug_tools_initialized, layer, configs_dir, feature_dirs) ...@@ -28,13 +28,15 @@ def _run_cpu_overhead(debug_tools_initialized, layer, configs_dir, feature_dirs)
model = torch.nn.Sequential( model = torch.nn.Sequential(
te.Linear(1, 1, name="linear1"), te.Linear(1, 1, name="linear2") te.Linear(1, 1, name="linear1"), te.Linear(1, 1, name="linear2")
).cuda() ).cuda()
NUM_ITERS = 18000 NUM_ITERS = 1800
elif layer == "transformer": elif layer == "transformer":
model = torch.nn.Sequential( model = torch.nn.Sequential(
te.TransformerLayer(1, 1, 1, name="transformer1"), te.TransformerLayer(1, 1, 1, name="transformer1"),
te.TransformerLayer(1, 1, 1, name="transformer2"), te.TransformerLayer(1, 1, 1, name="transformer2"),
).cuda() ).cuda()
NUM_ITERS = 2000 NUM_ITERS = 200
NUM_INVOCATIONS_PER_ITER = 10
x = torch.randn(1, 1, 1).cuda() x = torch.randn(1, 1, 1).cuda()
...@@ -45,8 +47,9 @@ def _run_cpu_overhead(debug_tools_initialized, layer, configs_dir, feature_dirs) ...@@ -45,8 +47,9 @@ def _run_cpu_overhead(debug_tools_initialized, layer, configs_dir, feature_dirs)
time_start = time.time() time_start = time.time()
for i in range(NUM_ITERS): for i in range(NUM_ITERS):
y = model(x) for _ in range(NUM_INVOCATIONS_PER_ITER):
y.sum().backward() y = model(x)
y.sum().backward()
if debug_tools_initialized: if debug_tools_initialized:
debug_api.step() debug_api.step()
torch.cuda.synchronize() torch.cuda.synchronize()
......
...@@ -1523,7 +1523,13 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1523,7 +1523,13 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
debug = False debug = False
else: else:
debug = TEDebugState.get_iteration() >= self.next_iter_when_debug_should_be_run debug = TEDebugState.get_iteration() >= self.next_iter_when_debug_should_be_run
self.debug_last_iteration = TEDebugState.get_iteration() self.debug_last_iteration = TEDebugState.get_iteration()
self.debug_enabled_in_this_iteration = debug
else:
# If this is the same iteration as previous invocation of the module,
# we use the debug value from the first invocation in the iteration.
debug = self.debug_enabled_in_this_iteration
return debug return debug
def no_debug_features_active(self, quantizers): def no_debug_features_active(self, quantizers):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment