Commit 2366716f authored by Jared Casper

Error, not warn, if gradient_accumulation_fusion is requested but not available.

parent 55817ec9
@@ -442,21 +442,22 @@ class ColumnParallelLinear(torch.nn.Module):
         if gradient_accumulation_fusion:
             if not _grad_accum_fusion_available:
-                # Basically, megatron.core users are expected to install APEX's
-                # `--cpp_ext` and `--cuda_ext`. The example installation command is as follows:
-                # `pip install --global-option="--cpp_ext" --global-option="--cuda_ext ."
-                # at the root of APEX repository.
-                warnings.warn(
-                    "`gradient_accumulation_fusion` is set to `True` but "
-                    "the custom CUDA extension of `fused_weight_gradient_mlp_cuda` module not "
-                    "found. Thus `gradient_accumulation_fusion` set to `False`. "
-                    "Note that the extension requires CUDA>=11."
-                )
-                gradient_accumulation_fusion = False
+                raise RuntimeError(
+                    "ColumnParallelLinear was called with gradient_accumulation_fusion set "
+                    "to True but the custom CUDA extension fused_weight_gradient_mlp_cuda "
+                    "module is not found. To use gradient_accumulation_fusion you must "
+                    "install APEX with --cpp_ext and --cuda_ext. For example: "
+                    "pip install --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\" "
+                    "Note that the extension requires CUDA>=11. Otherwise, you must turn off "
+                    "gradient accumulation fusion."
+                )
         self.gradient_accumulation_fusion = gradient_accumulation_fusion
 
         if self.async_tensor_model_parallel_allreduce and self.sequence_parallel_enabled:
-            raise RuntimeError("`async_tensor_model_parallel_allreduce` and `sequence_parallel_enabled` cannot be enabled at the same time.")
+            raise RuntimeError(
+                "`async_tensor_model_parallel_allreduce` and `sequence_parallel_enabled` "
+                "cannot be enabled at the same time."
+            )
 
     def forward(self, input_):
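For reference, the `_grad_accum_fusion_available` flag consulted above is set elsewhere in the module and is not part of this diff. A minimal sketch of how such a flag is typically derived, assuming the APEX extension is importable as fused_weight_gradient_mlp_cuda (the name cited in the new error message):

# Sketch only (not part of this commit): probe for the fused CUDA extension
# that APEX builds when installed with --cpp_ext and --cuda_ext, and record
# whether gradient accumulation fusion can be used at all.
try:
    import fused_weight_gradient_mlp_cuda  # assumes an APEX build with CUDA >= 11
    _grad_accum_fusion_available = True
except ImportError:
    _grad_accum_fusion_available = False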