Unverified Commit f56d2996 authored by lkchen's avatar lkchen Committed by GitHub
Browse files

[Misc] Respect `no_use_tqdm_on_load` flag while capturing CUDA graph (#20834)


Signed-off-by: default avatarLinkun <github@lkchen.net>
parent 147afb44
...@@ -2270,8 +2270,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -2270,8 +2270,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Only rank 0 should print progress bar during capture # Only rank 0 should print progress bar during capture
compilation_cases = reversed(self.cudagraph_batch_sizes) compilation_cases = reversed(self.cudagraph_batch_sizes)
if is_global_first_rank(): if is_global_first_rank():
compilation_cases = tqdm(list(compilation_cases), compilation_cases = tqdm(
desc="Capturing CUDA graph shapes") list(compilation_cases),
disable=not self.load_config.use_tqdm_on_load,
desc="Capturing CUDA graph shapes")
for num_tokens in compilation_cases: for num_tokens in compilation_cases:
# We skip EPLB here since we don't want to record dummy metrics # We skip EPLB here since we don't want to record dummy metrics
for _ in range( for _ in range(
......
...@@ -1587,6 +1587,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1587,6 +1587,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
if get_tensor_model_parallel_rank() == 0: if get_tensor_model_parallel_rank() == 0:
compilation_cases = tqdm( compilation_cases = tqdm(
list(compilation_cases), list(compilation_cases),
disable=not self.load_config.use_tqdm_on_load,
desc="Capturing CUDA graph shapes") desc="Capturing CUDA graph shapes")
for batch_size, use_inputs_embeds in compilation_cases: for batch_size, use_inputs_embeds in compilation_cases:
attn_metadata = ( attn_metadata = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment