Unverified commit 7950a82d authored by Masaki Kozuki, committed by GitHub

[transformer] Warn only when `gradient_accumulation_fusion` is `True` and `fused_weight_gradient_mlp_cuda` is missing (#1317)
parent a56e88dc
@@ -38,20 +38,11 @@ from apex.transformer.log_util import get_transformer_logger

 _logger = get_transformer_logger(__name__)

-_grad_accum_fusion_available = False
+_grad_accum_fusion_available = True
 try:
     import fused_weight_gradient_mlp_cuda
 except ImportError:
-    # Basically, apex.transformer module users are expected to install APEX's
-    # `--cpp_ext` and `--cuda_ext`. The example installation command is as follows:
-    # `pip install --global-option="--cpp_ext" --global-option="--cuda_ext ."
-    # at the root of APEX repository.
-    _logger.warning(
-        "`fused_weight_gradient_mlp_cuda` module not found. "
-        "gradient accumulation fusion with weight gradient computation disabled."
-    )
-else:
-    _grad_accum_fusion_available = True
+    _grad_accum_fusion_available = False

 _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
@@ -431,7 +422,21 @@ class ColumnParallelLinear(torch.nn.Module):
         self.async_tensor_model_parallel_allreduce = (
             not no_async_tensor_model_parallel_allreduce and
             world_size > 1)
-        self.gradient_accumulation_fusion = gradient_accumulation_fusion and _grad_accum_fusion_available
+        if gradient_accumulation_fusion:
+            if not _grad_accum_fusion_available:
+                # Basically, apex.transformer module users are expected to install APEX's
+                # `--cpp_ext` and `--cuda_ext` extensions. An example installation command,
+                # run at the root of the APEX repository, is:
+                # `pip install --global-option="--cpp_ext" --global-option="--cuda_ext" .`
+                import warnings
+                warnings.warn(
+                    "`gradient_accumulation_fusion` is set to `True` but the custom CUDA "
+                    "extension `fused_weight_gradient_mlp_cuda` was not found, so "
+                    "`gradient_accumulation_fusion` is set to `False`. "
+                    "Note that the extension requires CUDA>=11."
+                )
+                gradient_accumulation_fusion = False
+        self.gradient_accumulation_fusion = gradient_accumulation_fusion
         self._forward_impl = linear_with_grad_accumulation_and_async_allreduce_in16bit if accumulation_in_fp16 else linear_with_grad_accumulation_and_async_allreduce
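Taken together, the two hunks move the availability warning from import time to construction time: the module-level import probe is now silent, and a warning fires only when a caller explicitly requests the fused path. Below is a minimal, self-contained sketch of that pattern, not the actual apex source; `FusedLinear` is a hypothetical stand-in for `ColumnParallelLinear`.

import warnings

# Probe for the optional CUDA extension silently; importing this module
# should not warn users who never ask for gradient accumulation fusion.
_grad_accum_fusion_available = True
try:
    import fused_weight_gradient_mlp_cuda
except ImportError:
    _grad_accum_fusion_available = False


class FusedLinear:  # hypothetical stand-in for apex's ColumnParallelLinear
    def __init__(self, gradient_accumulation_fusion: bool = False):
        # Warn (and fall back) only when the feature was explicitly
        # requested but the extension is unavailable.
        if gradient_accumulation_fusion and not _grad_accum_fusion_available:
            warnings.warn(
                "`gradient_accumulation_fusion` is set to `True` but "
                "`fused_weight_gradient_mlp_cuda` was not found; falling "
                "back to `gradient_accumulation_fusion=False`."
            )
            gradient_accumulation_fusion = False
        self.gradient_accumulation_fusion = gradient_accumulation_fusion


# Importing the module stays silent; only an explicit opt-in can warn.
layer = FusedLinear(gradient_accumulation_fusion=True)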