Unverified Commit 5b155fb3 authored by JimmyZhang12's avatar JimmyZhang12 Committed by GitHub
Browse files

Recomputation fixes with native fp8 (#646)



* fixes for recomputation
Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>

* lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix onnx export [wip]
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* register op; fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Jimmy Zhang <jiemingz@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 2aee0591
......@@ -21,6 +21,7 @@ from .cpu_offload import get_cpu_offload_context
# Register custom op symbolic ONNX functions
from .te_onnx_extensions import (
onnx_cast_to_fp8,
onnx_cast_to_fp8_noalloc,
onnx_cast_from_fp8,
onnx_fp8_gelu,
onnx_fp8_relu,
......
......@@ -22,12 +22,13 @@ def cast_to_fp8(
"""Cast input to FP8"""
if out is not None:
tex.cast_to_fp8_noalloc(
torch.ops.tex_ts.cast_to_fp8_noalloc_ts(
inp,
fp8_meta_tensor.scale[fp8_tensor],
fp8_meta_tensor.scale,
out,
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_tensor,
otype
)
return None
......
......@@ -6,6 +6,7 @@
#include "extensions.h"
at::Tensor cast_to_fp8(const at::Tensor &input,
const at::Tensor &scale,
at::Tensor amax,
......
......@@ -7,6 +7,7 @@
#include <torch/script.h>
#include "extensions.h"
namespace {
transformer_engine::DType reverse_map_dtype(int64_t dtype) {
if (dtype >= 0 && dtype < static_cast<int64_t>(transformer_engine::DType::kNumTypes)) {
......@@ -20,8 +21,8 @@ namespace {
at::Tensor cast_to_fp8_ts(const at::Tensor &input,
const at::Tensor &scale,
const at::Tensor &amax,
const at::Tensor &scale_inv,
at::Tensor amax,
at::Tensor scale_inv,
int64_t fp8_tensor,
int64_t otype) {
transformer_engine::DType otype_arg = reverse_map_dtype(otype);
......@@ -33,6 +34,25 @@ at::Tensor cast_to_fp8_ts(const at::Tensor &input,
return output;
}
/// TorchScript-visible wrapper around cast_to_fp8_noalloc.
///
/// Takes the full per-layer FP8 meta tensors (scale, amax history,
/// scale_inv) plus a caller-preallocated `output` buffer, slices out the
/// slots belonging to `fp8_tensor`, and forwards them to the underlying
/// kernel. The "noalloc" in the name indicates the result is presumably
/// written into `output` in place rather than freshly allocated — the
/// buffer is returned so the op has a TorchScript-friendly return value.
at::Tensor cast_to_fp8_noalloc_ts(const at::Tensor &input,
                                  const at::Tensor &scale,
                                  at::Tensor output,
                                  at::Tensor amax,
                                  at::Tensor scale_inv,
                                  int64_t fp8_tensor,
                                  int64_t otype) {
  const transformer_engine::DType out_dtype = reverse_map_dtype(otype);
  // Select the meta slots for this particular tensor: amax history keeps
  // the current step in row 0, indexed per FP8 tensor.
  const at::Tensor scale_slot = scale[fp8_tensor];
  at::Tensor amax_slot = amax[0][fp8_tensor];
  at::Tensor scale_inv_slot = scale_inv[fp8_tensor];
  cast_to_fp8_noalloc(input, scale_slot, output, amax_slot, scale_inv_slot,
                      out_dtype);
  return output;
}
at::Tensor cast_from_fp8_ts(const at::Tensor &input,
const at::Tensor &scale_inv,
int64_t fp8_tensor,
......@@ -47,6 +67,7 @@ at::Tensor cast_from_fp8_ts(const at::Tensor &input,
return output;
}
at::Tensor gelu_ts(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
......@@ -82,6 +103,7 @@ at::Tensor gelu_ts(at::Tensor input,
return output;
}
at::Tensor relu_ts(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
......@@ -117,6 +139,7 @@ at::Tensor relu_ts(at::Tensor input,
return output;
}
at::Tensor reglu_ts(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
......@@ -152,6 +175,7 @@ at::Tensor reglu_ts(at::Tensor input,
return output;
}
at::Tensor geglu_ts(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
......@@ -187,6 +211,7 @@ at::Tensor geglu_ts(at::Tensor input,
return output;
}
at::Tensor swiglu_ts(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
......@@ -222,6 +247,7 @@ at::Tensor swiglu_ts(at::Tensor input,
return output;
}
at::Tensor te_gemm_ts(at::Tensor A,
at::Tensor A_scale_inverse,
int64_t A_fp8_tensor,
......@@ -286,6 +312,7 @@ at::Tensor te_gemm_ts(at::Tensor A,
return D;
}
at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
......@@ -312,6 +339,7 @@ at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input,
return output;
}
at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
......@@ -328,6 +356,7 @@ at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input,
return output;
}
at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input,
const at::Tensor &weight,
double eps,
......@@ -352,6 +381,7 @@ at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input,
return output;
}
at::Tensor rmsnorm_fwd_inf_ts(const at::Tensor &input,
const at::Tensor &weight,
double eps,
......@@ -366,8 +396,10 @@ at::Tensor rmsnorm_fwd_inf_ts(const at::Tensor &input,
return output;
}
TORCH_LIBRARY(tex_ts, m) {
m.def("cast_to_fp8_ts", &cast_to_fp8_ts);
m.def("cast_to_fp8_noalloc_ts", &cast_to_fp8_noalloc_ts);
m.def("cast_from_fp8_ts", &cast_from_fp8_ts);
m.def("gelu_ts", &gelu_ts);
m.def("relu_ts", &relu_ts);
......
......@@ -36,6 +36,8 @@ from ..distributed import (
allreduce,
reduce_scatter_along_first_dim,
gather_along_first_dim,
is_fp8_activation_recompute_enabled,
in_fp8_activation_recompute_phase,
)
from ..constants import GemmParallelModes, dist_group_type, TE_DType
from ..jit import no_torch_dynamo
......@@ -173,7 +175,9 @@ class _LayerNormLinear(torch.autograd.Function):
fp8_meta=fp8_meta,
fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
)
if is_grad_enabled:
if (is_grad_enabled
or (is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase())):
tex.fp8_cast_transpose_fused(
weight,
fp8_meta["scaling_fwd"],
......@@ -183,11 +187,12 @@ class _LayerNormLinear(torch.autograd.Function):
transpose_out=weight_t_fp8._data,
)
else:
weight_fp8._data = tex.cast_to_fp8(
tex.cast_to_fp8(
weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
out=weight_fp8._data,
)
weight_t_fp8 = None
......
......@@ -41,6 +41,8 @@ from ..distributed import (
allreduce,
reduce_scatter_along_first_dim,
gather_along_first_dim,
is_fp8_activation_recompute_enabled,
in_fp8_activation_recompute_phase,
)
from .. import cpp_extensions as tex
......@@ -219,7 +221,9 @@ class _LayerNormMLP(torch.autograd.Function):
fp8_meta=fp8_meta,
fp8_meta_index=tex.FP8FwdTensors.GEMM2_WEIGHT,
)
if is_grad_enabled:
if (is_grad_enabled
or (is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase())):
# Fused cast-transpose kernels
tex.fp8_cast_transpose_fused(
fc1_weight,
......@@ -238,18 +242,20 @@ class _LayerNormMLP(torch.autograd.Function):
transpose_out=fc2_weight_t_fp8._data,
)
else:
fc1_weight_fp8._data = tex.cast_to_fp8(
tex.cast_to_fp8(
fc1_weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
out=fc1_weight_fp8._data,
)
fc1_weight_t_fp8 = None
fc2_weight_fp8._data = tex.cast_to_fp8(
tex.cast_to_fp8(
fc2_weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM2_WEIGHT,
fp8_dtype_forward,
out=fc2_weight_fp8._data,
)
fc2_weight_t_fp8 = None
......
......@@ -34,6 +34,8 @@ from ..distributed import (
allreduce,
reduce_scatter_along_first_dim,
gather_along_first_dim,
is_fp8_activation_recompute_enabled,
in_fp8_activation_recompute_phase,
)
from ..cpp_extensions import (
fp8_gemm,
......@@ -155,7 +157,9 @@ class _Linear(torch.autograd.Function):
fp8_meta=fp8_meta,
fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
)
if is_grad_enabled:
if (is_grad_enabled
or (is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase())):
fp8_cast_transpose_fused(
weight,
fp8_meta["scaling_fwd"],
......@@ -165,11 +169,12 @@ class _Linear(torch.autograd.Function):
transpose_out=weight_t_fp8._data,
)
else:
weight_fp8._data = cast_to_fp8(
cast_to_fp8(
weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
out=weight_fp8._data,
)
weight_t_fp8 = None
......
......@@ -130,6 +130,13 @@ def onnx_cast_to_fp8(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
return quantize(g, inputs, scale_inv, fp8_tensor)
@symbolic_helper.parse_args("v", "v", "v", "v", "fs", "i", "i")
def onnx_cast_to_fp8_noalloc(g, inputs, scale, output, amax, scale_inv, fp8_tensor, otype):
    """ONNX graph for cast_to_fp8_noalloc.

    Exported as a plain quantize of ``inputs`` using ``scale_inv`` and
    ``fp8_tensor``; the preallocated ``output`` buffer, ``scale``,
    ``amax`` bookkeeping and ``otype`` have no role in the exported
    graph and are intentionally ignored.
    """
    # pylint: disable=unused-argument
    quantized = quantize(g, inputs, scale_inv, fp8_tensor)
    return quantized
@symbolic_helper.parse_args("v", "fs", "i", "i", "i")
def onnx_cast_from_fp8(g, inputs, scale_inv, fp8_tensor, itype, otype):
"""ONNX graph for cast_from_fp8"""
......@@ -393,10 +400,11 @@ def onnx_rmsnorm_fwd(g, inputs, weight, eps, zero_centered_gamma):
result = g.op("Mul", weight, normalized_input)
result = g.op("Cast", result, to_i=get_TensorProtoDataType(inputs))
return result
register_custom_op_symbolic('tex_ts::cast_to_fp8_ts', onnx_cast_to_fp8, VER)
register_custom_op_symbolic('tex_ts::cast_to_fp8_noalloc_ts', onnx_cast_to_fp8_noalloc, VER)
register_custom_op_symbolic('tex_ts::cast_from_fp8_ts', onnx_cast_from_fp8, VER)
register_custom_op_symbolic('tex_ts::gelu_ts', onnx_fp8_gelu, VER)
register_custom_op_symbolic('tex_ts::relu_ts', onnx_fp8_relu, VER)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment