Commit 4099aa8e authored by yuguo
Parents: c520cba3 96f9c6de
@@ -177,16 +177,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 #ifdef USE_ROCM
   m.def("te_general_batched_gemm", &te_general_batched_gemm, "Batched GEMM");  /// rocblas
 #endif
-  m.def("fused_attn_fwd", &fused_attn_fwd,
-        "Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V");
-  m.def("fused_attn_bwd", &fused_attn_bwd,
-        "Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V");
   m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O", py::arg("input"),
         py::arg("dtype"), py::kw_only(), py::arg("out"), py::call_guard<py::gil_scoped_release>());
-  m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention",
-        py::call_guard<py::gil_scoped_release>());
-  m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention",
-        py::call_guard<py::gil_scoped_release>());
   m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend",
         py::call_guard<py::gil_scoped_release>());
   m.def("fused_amax_and_scale_update_after_reduction", &fused_amax_and_scale_update_after_reduction,
@@ -194,6 +186,20 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::call_guard<py::gil_scoped_release>());
   m.def("fused_multi_row_padding", &fused_multi_row_padding, "Fused Multi-tensor padding",
         py::call_guard<py::gil_scoped_release>());
+  // attention kernels
+  m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention",
+        py::call_guard<py::gil_scoped_release>());
+  m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention",
+        py::call_guard<py::gil_scoped_release>());
+  m.def("fused_attn_fwd", &fused_attn_fwd,
+        "Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V");
+  m.def("fused_attn_bwd", &fused_attn_bwd,
+        "Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V");
+  m.def("copy_to_kv_cache", &copy_to_kv_cache, "Copy new KV tokens to KV cache");
+  m.def("convert_thd_to_bshd", &convert_thd_to_bshd, "Convert a tensor from THD to BSHD");
+  m.def("convert_bshd_to_thd", &convert_bshd_to_thd, "Convert a tensor from BSHD to THD");
   // fused apply rope
   m.def("fused_rope_forward", &fused_rope_forward, "Fused Apply RoPE FWD",
         py::call_guard<py::gil_scoped_release>());
...
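These bindings are compiled into the extension module named by TORCH_EXTENSION_NAME, which Transformer Engine normally builds as transformer_engine_torch (an assumption here, not stated in the diff). A minimal sketch of checking that the new KV-cache entry points are exposed to Python:

# Hedged sketch: verify the new bindings exist on the compiled extension.
import transformer_engine  # makes sure the core library is loaded first
import transformer_engine_torch as tex

for name in ("copy_to_kv_cache", "convert_thd_to_bshd", "convert_bshd_to_thd"):
    print(name, hasattr(tex, name))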
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_KV_CACHE_CUH_
#define TRANSFORMER_ENGINE_FUSED_ATTN_KV_CACHE_CUH_
namespace transformer_engine {
namespace fused_attn {
template <typename scalar_t>
__global__ void convert_thd_to_bshd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens,
int b, int max_seq_len, int h, int d) {
// tensor: thd; new_tensor: bshd
// cu_seqlens: [b + 1]
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int num_elts = (cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]) * h * d;
int thd_offset = cu_seqlens[batch_idx] * h * d;
int bshd_offset = batch_idx * max_seq_len * h * d;
scalar_t *thd_token = tensor + thd_offset;
scalar_t *bshd_token = new_tensor + bshd_offset;
for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
*(bshd_token + i) = *(thd_token + i);
}
}
}
template <typename scalar_t>
__global__ void convert_bshd_to_thd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens,
int b, int max_seq_len, int h, int d) {
// tensor: bshd; new_tensor: thd
// cu_seqlens: [b + 1]
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx];
int num_elts = seqlen * h * d;
int bshd_offset = batch_idx * max_seq_len * h * d;
int thd_offset = cu_seqlens[batch_idx] * h * d;
scalar_t *bshd_token = tensor + bshd_offset;
scalar_t *thd_token = new_tensor + thd_offset;
for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
*(thd_token + i) = *(bshd_token + i);
}
}
}
template <typename scalar_t>
__global__ void reindex_kv_cache_kernel(scalar_t *k_cache, scalar_t *v_cache, int *batch_indices,
int *cu_new_lens, int *cu_cached_lens, int h_kv, int d_k,
int d_v, int b, int max_seq_len) {
// k_cache, v_cache: bshd
// batch_indices: [b]; cu_new_lens, cu_cached_lens: [b + 1]
int actual_b = b;
for (int i = 0; i < b - 1; i++) {
if (batch_indices[i + 1] < batch_indices[i]) {
actual_b = i + 1;
}
}
for (int batch_idx = 0; batch_idx < actual_b; batch_idx++) {
int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx];
int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
for (int token_idx = blockIdx.x; token_idx < cached_len - new_len; token_idx += gridDim.x) {
int num_elts_k = h_kv * d_k;
int num_elts_v = h_kv * d_v;
int k_cache_src_offset = (batch_indices[batch_idx] * max_seq_len + token_idx) * h_kv * d_k;
int k_cache_des_offset = (batch_idx * max_seq_len + token_idx) * h_kv * d_k;
int v_cache_src_offset = (batch_indices[batch_idx] * max_seq_len + token_idx) * h_kv * d_v;
int v_cache_des_offset = (batch_idx * max_seq_len + token_idx) * h_kv * d_v;
for (int i = threadIdx.x; i < num_elts_k; i += blockDim.x) {
*(k_cache + k_cache_des_offset + i) = *(k_cache + k_cache_src_offset + i);
}
for (int i = threadIdx.x; i < num_elts_v; i += blockDim.x) {
*(v_cache + v_cache_des_offset + i) = *(v_cache + v_cache_src_offset + i);
}
}
}
}
template <typename scalar_t>
__global__ void copy_to_kv_cache_kernel(scalar_t *new_k, scalar_t *new_v, scalar_t *k_cache,
scalar_t *v_cache, int *page_table, int *cu_new_lens,
int *cu_cached_lens, NVTE_QKV_Format qkv_format, int h_kv,
int d_k, int d_v, int b, int max_ctx_len, int max_seq_len,
int max_pages_per_seq, bool is_non_paged) {
// new_k, new_v: qkv_format; k_cache, v_cache: bshd
// cu_new_lens, cu_cached_lens: [b + 1]
// page_table: [b, max_pages_per_seq]
int page_size = max_seq_len / max_pages_per_seq;
if (qkv_format == NVTE_QKV_Format::NVTE_BSHD) {
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq;
int new_token_offset = batch_idx * max_ctx_len;
int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx];
int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
for (int i = threadIdx.x; i < new_len; i += blockDim.x) {
int page_idx = is_non_paged ? batch_idx : page_list[(cached_len - new_len + i) / page_size];
int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size;
for (int j = 0; j < h_kv * d_k; j++) {
*(k_cache + token_idx * h_kv * d_k + j) =
*(new_k + (new_token_offset + i) * h_kv * d_k + j);
}
for (int j = 0; j < h_kv * d_v; j++) {
*(v_cache + token_idx * h_kv * d_v + j) =
*(new_v + (new_token_offset + i) * h_kv * d_v + j);
}
}
}
} else if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) {
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq;
int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx];
int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
for (int i = threadIdx.x; i < new_len; i += blockDim.x) {
int page_idx = is_non_paged ? batch_idx : page_list[(cached_len - new_len + i) / page_size];
int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size;
for (int j = 0; j < h_kv * d_k; j++) {
*(k_cache + token_idx * h_kv * d_k + j) = *(new_k + (i * b + batch_idx) * h_kv * d_k + j);
}
for (int j = 0; j < h_kv * d_v; j++) {
*(v_cache + token_idx * h_kv * d_v + j) = *(new_v + (i * b + batch_idx) * h_kv * d_v + j);
}
}
}
} else if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq;
int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx];
int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
for (int i = threadIdx.x; i < new_len; i += blockDim.x) {
int page_idx = is_non_paged ? batch_idx : page_list[(cached_len - new_len + i) / page_size];
int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size;
for (int j = 0; j < h_kv * d_k; j++) {
*(k_cache + token_idx * h_kv * d_k + j) =
*(new_k + (cu_new_lens[batch_idx] + i) * h_kv * d_k + j);
}
for (int j = 0; j < h_kv * d_v; j++) {
*(v_cache + token_idx * h_kv * d_v + j) =
*(new_v + (cu_new_lens[batch_idx] + i) * h_kv * d_v + j);
}
}
}
}
}
} // namespace fused_attn
} // namespace transformer_engine
#endif
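The kernels above assign one block per batch entry (grid-striding over b) and let the threads of a block cooperate over the h * d elements of each token. A hedged PyTorch reference of what convert_thd_to_bshd_kernel computes, useful for testing the layout conversion; the real kernel writes into a caller-allocated output, so the zero-padding below is an assumption for illustration only:

import torch

def convert_thd_to_bshd_reference(thd, cu_seqlens, b, max_seq_len):
    """Scatter packed [total_tokens, h, d] sequences into a padded [b, max_seq_len, h, d] tensor."""
    _, h, d = thd.shape
    bshd = thd.new_zeros((b, max_seq_len, h, d))
    for batch_idx in range(b):
        start, end = int(cu_seqlens[batch_idx]), int(cu_seqlens[batch_idx + 1])
        bshd[batch_idx, : end - start] = thd[start:end]
    return bshd

convert_bshd_to_thd_kernel is the exact inverse gather, and copy_to_kv_cache_kernel performs the same per-token copy but routes the destination index through page_table when is_non_paged is false.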
@@ -20,7 +20,7 @@ from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_module
 from .utils import safely_set_viewless_tensor_data
 from .constants import dist_group_type
-from .fp8 import FP8GlobalStateManager
+from .fp8 import FP8GlobalStateManager, fp8_autocast
 from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer
 from .tensor.mxfp8_tensor import MXFP8Quantizer
 from .tensor.quantized_tensor import QuantizedTensor, Quantizer
@@ -328,11 +328,14 @@ class _CheckpointFunction(torch.autograd.Function):
         tensor_inputs = [arg if torch.is_tensor(arg) else None for arg in args]
         ctx.save_for_backward(*tensor_inputs)
+        fp8 = FP8GlobalStateManager.is_fp8_enabled()
         ctx.get_rng_state_tracker = get_rng_state_tracker
         ctx.tp_group = tp_group
         ctx.recompute_ctx = recompute_ctx
         ctx.torch_gpu_amp_ctx = torch_gpu_amp_ctx
         ctx.torch_cpu_amp_ctx = torch_cpu_amp_ctx
+        ctx.fp8 = fp8
+        ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
         ctx.kwargs = kwargs

         return outputs
@@ -375,6 +378,8 @@ class _CheckpointFunction(torch.autograd.Function):
         detached_inputs = detach_variable(inputs)
         with torch.enable_grad(), ctx.recompute_ctx, ctx.torch_gpu_amp_ctx, ctx.torch_cpu_amp_ctx, activation_recompute_forward(
             activation_recompute=True, recompute_phase=True
+        ), fp8_autocast(
+            enabled=ctx.fp8, fp8_recipe=ctx.fp8_recipe
         ):
             outputs = ctx.run_function(*detached_inputs, **ctx.kwargs)
@@ -398,6 +403,9 @@ class _CheckpointFunction(torch.autograd.Function):
                 "none of output has requires_grad=True, this checkpoint() is not necessary"
             )

+        # backward does not require entering autocast context because
+        # backward implementations already retrieve fp8 recipe and
+        # enablement from stored ctx.
         torch.autograd.backward(outputs_with_grad, args_with_grad)
         grads = tuple(
             inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs
@@ -694,10 +702,15 @@ def checkpoint(
     # Preserve the torch autocast contexts from the forward pass during recompute phase.
     torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx = _get_active_autocast_contexts()

+    fp8 = FP8GlobalStateManager.is_fp8_enabled()
+    fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
+
     def recompute_fn(*args, **kwargs):
         with torch.autograd.enable_grad(), (
             te_recompute_ctx
-        ), user_recompute_ctx, torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx:
+        ), user_recompute_ctx, torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx, fp8_autocast(
+            enabled=fp8, fp8_recipe=fp8_recipe
+        ):
             function(*args, **kwargs)

     # Initialize a new checkpoint frame for each new forward pass.
...
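The checkpointing change above records whether FP8 was enabled (and which recipe was active) when the forward pass ran, and re-enters fp8_autocast with that state during recomputation so the recomputed forward matches the original. A minimal usage sketch, assuming an FP8-capable GPU and the public transformer_engine.pytorch checkpoint/fp8_autocast APIs plus the DelayedScaling recipe:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

layer = te.Linear(1024, 1024).cuda()
x = torch.randn(32, 1024, device="cuda", requires_grad=True)

with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
    # The checkpoint captures the active FP8 state; the recompute phase in
    # backward re-enters fp8_autocast with the same recipe.
    y = te.checkpoint(layer, x)
y.sum().backward()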
@@ -91,6 +91,14 @@ def _make_graphed_callables(
         sample_args = (sample_args,)
         sample_kwargs = (sample_kwargs,)

+    # Check training/inference
+    is_training = all(c.training for c in callables)
+    if not is_training and any(c.training for c in callables):
+        assert False, (
+            "make_graphed_callables only supports when modules are all in training or all in"
+            " inference mode."
+        )
+
     # Check sizes of args
     if _order is None:
         assert len(sample_args) == len(callables)
@@ -255,13 +263,16 @@ def _make_graphed_callables(
                 outputs, _ = _tree_flatten(func(*args, **kwargs))
             for hook in hooks:
                 hook.remove()
-            grad_inputs = torch.autograd.grad(
-                outputs=tuple(o for o in outputs if o.requires_grad),
-                inputs=tuple(i for i in static_input_surface if i.requires_grad),
-                grad_outputs=tuple(torch.empty_like(o) for o in outputs if o.requires_grad),
-                only_inputs=True,
-                allow_unused=allow_unused_input,
-            )
+            if is_training:
+                grad_inputs = torch.autograd.grad(
+                    outputs=tuple(o for o in outputs if o.requires_grad),
+                    inputs=tuple(i for i in static_input_surface if i.requires_grad),
+                    grad_outputs=tuple(torch.empty_like(o) for o in outputs if o.requires_grad),
+                    only_inputs=True,
+                    allow_unused=allow_unused_input,
+                )
+            else:
+                grad_inputs = None
             del outputs, grad_inputs
         # The following code is added specifically for MCore's special requirements,
         # aimed at preventing warmup from altering the control flow.
@@ -314,22 +325,23 @@ def _make_graphed_callables(
         static_grad_outputs = tuple(
             torch.empty_like(o) if o.requires_grad else None for o in static_outputs
         )
-        with torch.cuda.graph(bwd_graph, pool=mempool):
-            grad_inputs = torch.autograd.grad(
-                outputs=tuple(o for o in static_outputs if o.requires_grad),
-                inputs=tuple(i for i in static_input_surface if i.requires_grad),
-                grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
-                only_inputs=True,
-                allow_unused=allow_unused_input,
-                retain_graph=retain_graph_in_backward,
-            )
+        if is_training:
+            with torch.cuda.graph(bwd_graph, pool=mempool):
+                grad_inputs = torch.autograd.grad(
+                    outputs=tuple(o for o in static_outputs if o.requires_grad),
+                    inputs=tuple(i for i in static_input_surface if i.requires_grad),
+                    grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
+                    only_inputs=True,
+                    allow_unused=allow_unused_input,
+                    retain_graph=retain_graph_in_backward,
+                )
         # Constructs a tuple suitable for returning from Graphed.backward:
         # Pads out the actually-needed grads with Nones in gradient slots for inputs
         # that don't require grad. I couldn't think of a one-liner for this pattern.
         static_grad_inputs = []
         grad_idx = 0
         for arg in static_input_surface:
-            if arg.requires_grad:
+            if is_training and isinstance(arg, torch.Tensor) and arg.requires_grad:
                 static_grad_inputs.append(grad_inputs[grad_idx])
                 grad_idx += 1
             else:
@@ -366,22 +378,23 @@ def _make_graphed_callables(
         static_grad_outputs = tuple(
             torch.empty_like(o) if o.requires_grad else None for o in static_outputs
         )
-        with torch.cuda.graph(bwd_graph, pool=mempool):
-            grad_inputs = torch.autograd.grad(
-                outputs=tuple(o for o in static_outputs if o.requires_grad),
-                inputs=tuple(i for i in static_input_surface if i.requires_grad),
-                grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
-                only_inputs=True,
-                allow_unused=allow_unused_input,
-                retain_graph=retain_graph_in_backward,
-            )
+        if is_training:
+            with torch.cuda.graph(bwd_graph, pool=mempool):
+                grad_inputs = torch.autograd.grad(
+                    outputs=tuple(o for o in static_outputs if o.requires_grad),
+                    inputs=tuple(i for i in static_input_surface if i.requires_grad),
+                    grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
+                    only_inputs=True,
+                    allow_unused=allow_unused_input,
+                    retain_graph=retain_graph_in_backward,
+                )
         # Constructs a tuple suitable for returning from Graphed.backward:
         # Pads out the actually-needed grads with Nones in gradient slots for inputs that
         # don't require grad. I couldn't think of a slick one-liner for this pattern.
         static_grad_inputs = []
         grad_idx = 0
         for arg in static_input_surface:
-            if arg.requires_grad:
+            if is_training and isinstance(arg, torch.Tensor) and arg.requires_grad:
                 static_grad_inputs.append(grad_inputs[grad_idx])
                 grad_idx += 1
             else:
@@ -422,7 +435,10 @@ def _make_graphed_callables(
             # Copy values from new tensors into static tensors
             for i in range(len_user_args):
-                if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
+                if (
+                    isinstance(static_input_surface[i], torch.Tensor)
+                    and static_input_surface[i].data_ptr() != inputs[i].data_ptr()
+                ):
                     static_input_surface[i].copy_(inputs[i])

             # Replay forward graph
...
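With the is_training check above, _make_graphed_callables skips the warmup autograd passes and backward-graph capture entirely when every callable is in eval mode (and asserts if training and eval callables are mixed). A hedged sketch of what this enables through the public entry point, assuming it forwards to the helper shown here:

import torch
import transformer_engine.pytorch as te

model = te.Linear(1024, 1024).cuda().eval()   # all callables in inference mode
sample_input = torch.randn(32, 1024, device="cuda")

# Only forward graphs are captured for eval-mode modules.
graphed_model = te.make_graphed_callables(model, (sample_input,))

with torch.no_grad():
    out = graphed_model(sample_input)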
@@ -79,7 +79,6 @@ class _LayerNormLinear(torch.autograd.Function):
         ln_bias: Union[torch.Tensor, None],
         weight: torch.Tensor,
         bias: torch.Tensor,
-        use_bias: bool,
         eps: float,
         is_first_microbatch: Union[bool, None],
         fp8: bool,
@@ -383,6 +382,17 @@ class _LayerNormLinear(torch.autograd.Function):
             )
             nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")

+        if cpu_offloading:
+            ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
+            if ctx.grad_added_to_main_grad:
+                # If you are passing torch.nn.Parameter through the Torch hooks, you will
+                # get back torch.Tensor. Torch rips off the Parameter wrapper.
+                # You need to preserve the weight object to have all the attributes user
+                # sets for the weights. Because of this, it is not recommended to offload
+                # weights if weights are externally touched outside this module
+                ctx.weight_object = weight
+
         tensors_to_save, tensor_objects = prepare_for_saving(
             inputmat,
             weightmat,
@@ -411,7 +421,7 @@ class _LayerNormLinear(torch.autograd.Function):
         ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
         ctx.cpu_offloading = cpu_offloading
         ctx.is_first_microbatch = is_first_microbatch
-        ctx.use_bias = use_bias
+        ctx.use_bias = bias is not None
         ctx.sequence_parallel = sequence_parallel
         ctx.tensor_parallel = tensor_parallel
         ctx.inp_shape = inp_shape
@@ -526,8 +536,11 @@ class _LayerNormLinear(torch.autograd.Function):
         # For CPU offloading, we offloaded weight and weight.main_grad to different tensors,
         # we need to connect them into one.
-        if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
-            weight.main_grad = main_grad
+        if ctx.cpu_offloading:
+            if ctx.grad_added_to_main_grad:
+                origin_weight = ctx.weight_object
+            if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
+                origin_weight.main_grad = main_grad

         ctx.ub_obj_gradout = None
         ub_obj_dgrad = None
@@ -742,10 +755,6 @@ class _LayerNormLinear(torch.autograd.Function):
                 # TODO (pgadzinski) - deallocate transpose only  # pylint: disable=fixme
                 clear_tensor_data(ln_out_total)

-        # Don't return grad bias if not needed
-        if not ctx.use_bias:
-            grad_bias = None
-
         # Synchronize tensor parallel communication
         if ln_out_total_work is not None:
             ln_out_total_work.wait()
@@ -827,7 +836,6 @@ class _LayerNormLinear(torch.autograd.Function):
             dbeta,
             wgrad,
             grad_bias,
-            None,  # use_bias
             None,  # eps
             None,  # is_first_microbatch
             None,  # fp8
@@ -1330,8 +1338,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
             self.layer_norm_weight,
             self.layer_norm_bias,
             weight_tensor,
-            bias_tensor,
-            self.apply_bias and not self.gemm_bias_unfused_add,
+            bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None,
             self.eps,
             is_first_microbatch,
             self.fp8,
...
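The _LayerNormLinear change above (and the _LayerNormMLP change that follows) drops the separate use_bias flags: the caller now passes the bias tensor or None, and the autograd function derives everything from that, including whether a bias gradient slot is returned. A generic, hedged illustration of the pattern, not TE's actual implementation:

import torch
from typing import Optional

class _LinearLike(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]):
        ctx.save_for_backward(inp, weight)
        ctx.use_bias = bias is not None     # replaces the old explicit use_bias argument
        out = inp @ weight.t()
        return out + bias if ctx.use_bias else out

    @staticmethod
    def backward(ctx, grad_out):
        inp, weight = ctx.saved_tensors
        grad_bias = grad_out.sum(0) if ctx.use_bias else None  # no separate flag slot to fill
        return grad_out @ weight, grad_out.t() @ inp, grad_bias

out = _LinearLike.apply(torch.randn(4, 8, requires_grad=True), torch.randn(3, 8, requires_grad=True), None)
out.sum().backward()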
...@@ -140,10 +140,8 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -140,10 +140,8 @@ class _LayerNormMLP(torch.autograd.Function):
ln_bias: torch.Tensor, ln_bias: torch.Tensor,
fc1_weight: torch.Tensor, fc1_weight: torch.Tensor,
fc1_bias: torch.Tensor, fc1_bias: torch.Tensor,
use_fc1_bias: bool,
fc2_weight: torch.Tensor, fc2_weight: torch.Tensor,
fc2_bias: torch.Tensor, fc2_bias: torch.Tensor,
use_fc2_bias: bool,
eps: float, eps: float,
is_first_microbatch: Union[bool, None], is_first_microbatch: Union[bool, None],
fp8: bool, fp8: bool,
...@@ -368,7 +366,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -368,7 +366,7 @@ class _LayerNormMLP(torch.autograd.Function):
# FC1 GEMM # FC1 GEMM
# There are 2 fussions possible: # There are 2 fusions possible:
# - gemm_gelu_fusion - default for full precision, optional for fp8 - need to turn on gemm_gelu_fusion, # - gemm_gelu_fusion - default for full precision, optional for fp8 - need to turn on gemm_gelu_fusion,
# - bias_gelu_fusion - only for full precision. # - bias_gelu_fusion - only for full precision.
# If both gemm_gelu_fusion and bias_gelu_fusion are enabled, only bias_gelu_fusion will be performer # If both gemm_gelu_fusion and bias_gelu_fusion are enabled, only bias_gelu_fusion will be performer
...@@ -453,8 +451,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -453,8 +451,7 @@ class _LayerNormMLP(torch.autograd.Function):
) )
if not is_grad_enabled: if not is_grad_enabled:
clear_tensor_data(act_out, fc1_out_without_bias, fc1_out) clear_tensor_data(act_out, fc1_out_without_bias, fc1_out)
else:
if is_grad_enabled:
if cpu_offloading: if cpu_offloading:
if fp8 and fc1_weight_final is not None: if fp8 and fc1_weight_final is not None:
set_offloading_param(fc1_weight_final, "weight_offloading", True) set_offloading_param(fc1_weight_final, "weight_offloading", True)
...@@ -537,9 +534,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -537,9 +534,7 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
ctx.cpu_offloading = cpu_offloading ctx.cpu_offloading = cpu_offloading
ctx.is_first_microbatch = is_first_microbatch ctx.is_first_microbatch = is_first_microbatch
ctx.use_fc1_bias = use_fc1_bias ctx.use_bias = fc2_bias is not None
ctx.use_fc2_bias = use_fc2_bias
ctx.use_bias = ctx.use_fc1_bias
ctx.sequence_parallel = sequence_parallel ctx.sequence_parallel = sequence_parallel
ctx.tensor_parallel = tensor_parallel ctx.tensor_parallel = tensor_parallel
ctx.inp_shape = inp_shape ctx.inp_shape = inp_shape
...@@ -774,14 +769,13 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -774,14 +769,13 @@ class _LayerNormMLP(torch.autograd.Function):
quantization_params=None, # wgrad in high precision quantization_params=None, # wgrad in high precision
layout="NT", layout="NT",
grad=True, grad=True,
bias=fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None, bias=fc2_bias if fc2_bias_grad is None else None,
accumulate=accumulate_wgrad_into_param_main_grad, accumulate=accumulate_wgrad_into_param_main_grad,
use_split_accumulator=_2X_ACC_WGRAD, use_split_accumulator=_2X_ACC_WGRAD,
out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None, out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
) )
if fc2_bias_grad is None: if fc2_bias_grad is None:
fc2_bias_grad = fc2_bias_grad_ fc2_bias_grad = fc2_bias_grad_
del fc2_bias_grad_
clear_tensor_data(act_out) clear_tensor_data(act_out)
# bias computation # bias computation
...@@ -1046,11 +1040,9 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1046,11 +1040,9 @@ class _LayerNormMLP(torch.autograd.Function):
dgamma, dgamma,
dbeta, dbeta,
fc1_wgrad, fc1_wgrad,
fc1_bias_grad if ctx.use_fc1_bias else None, fc1_bias_grad if fc1_bias is not None else None,
None, # use_fc1_bias
fc2_wgrad, # pylint: disable=possibly-used-before-assignment fc2_wgrad, # pylint: disable=possibly-used-before-assignment
fc2_bias_grad if ctx.use_fc2_bias else None, fc2_bias_grad,
None, # use_fc2_bias
None, # eps None, # eps
None, # is_first_microbatch None, # is_first_microbatch
None, # fp8 None, # fp8
...@@ -1471,10 +1463,8 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1471,10 +1463,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.layer_norm_bias, self.layer_norm_bias,
fc1_weight, fc1_weight,
fc1_bias, fc1_bias,
self.use_bias,
fc2_weight, fc2_weight,
fc2_bias, fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None,
self.apply_bias and not self.gemm_bias_unfused_add,
self.eps, self.eps,
is_first_microbatch, is_first_microbatch,
self.fp8, self.fp8,
......
@@ -291,6 +291,17 @@ class _Linear(torch.autograd.Function):
             )
             nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")

+        if cpu_offloading:
+            ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
+            if ctx.grad_added_to_main_grad:
+                # If you are passing torch.nn.Parameter through the Torch hooks, you will
+                # get back torch.Tensor. Torch rips off the Parameter wrapper.
+                # You need to preserve the weight object to have all the attributes user
+                # sets for the weights. Because of this, it is not recommended to offload
+                # weights if weights are externally touched outside this module
+                ctx.weight_object = weight
+
         # TODO(ksivamani): Check memory usage
         tensors_to_save, tensor_objects = prepare_for_saving(
             saved_inputmat,
@@ -392,9 +403,11 @@ class _Linear(torch.autograd.Function):
                 else None
             )

-        if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
-            weight = torch.nn.Parameter(weight, weight.requires_grad)
-            weight.main_grad = main_grad
+        if ctx.cpu_offloading:
+            if ctx.grad_added_to_main_grad:
+                weight = ctx.weight_object
+            if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
+                weight.main_grad = main_grad

         # Gather intermediate/activation tensors if needed
         # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already
...
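The cpu_offloading blocks added to _LayerNormLinear and _Linear stash the original weight on ctx because offloading-style saved-tensor hooks hand back plain torch.Tensor objects in backward, losing the nn.Parameter wrapper and any custom attributes such as main_grad. A hedged, self-contained illustration of that behaviour (the hooks and names below are illustrative, not TE code):

import torch
from torch.autograd.graph import saved_tensors_hooks

def pack_to_cpu(t):
    return t.detach().to("cpu")  # offload-style repack: drops the Parameter wrapper and attributes

def unpack_from_cpu(t):
    return t                     # stand-in for copying back to the GPU

weight = torch.nn.Parameter(torch.ones(4))
weight.main_grad = torch.zeros(4)  # attribute used by fused wgrad accumulation

observed = {}

class Mul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(w)
        ctx.weight_object = w      # keep the real Parameter, as the modules now do
        return x * w

    @staticmethod
    def backward(ctx, g):
        (w_saved,) = ctx.saved_tensors
        observed["is_parameter"] = isinstance(w_saved, torch.nn.Parameter)        # False
        observed["has_main_grad"] = hasattr(w_saved, "main_grad")                 # False
        observed["kept_has_main_grad"] = hasattr(ctx.weight_object, "main_grad")  # True
        return g * ctx.weight_object, None

x = torch.ones(4, requires_grad=True)
with saved_tensors_hooks(pack_to_cpu, unpack_from_cpu):
    y = Mul.apply(x, weight)
y.sum().backward()
print(observed)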