Unverified Commit 5b155fb3 authored by JimmyZhang12's avatar JimmyZhang12 Committed by GitHub
Browse files

Recomputation fixes with native fp8 (#646)



* fixes for recomputation
Signed-off-by: default avatarJimmy Zhang <jiemingz@nvidia.com>

* lint
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix onnx export [wip]
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* register op; fixes
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarJimmy Zhang <jiemingz@nvidia.com>
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: default avatarJimmy Zhang <jiemingz@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 2aee0591
...@@ -21,6 +21,7 @@ from .cpu_offload import get_cpu_offload_context ...@@ -21,6 +21,7 @@ from .cpu_offload import get_cpu_offload_context
# Register custom op symbolic ONNX functions # Register custom op symbolic ONNX functions
from .te_onnx_extensions import ( from .te_onnx_extensions import (
onnx_cast_to_fp8, onnx_cast_to_fp8,
onnx_cast_to_fp8_noalloc,
onnx_cast_from_fp8, onnx_cast_from_fp8,
onnx_fp8_gelu, onnx_fp8_gelu,
onnx_fp8_relu, onnx_fp8_relu,
......
...@@ -22,12 +22,13 @@ def cast_to_fp8( ...@@ -22,12 +22,13 @@ def cast_to_fp8(
"""Cast input to FP8""" """Cast input to FP8"""
if out is not None: if out is not None:
tex.cast_to_fp8_noalloc( torch.ops.tex_ts.cast_to_fp8_noalloc_ts(
inp, inp,
fp8_meta_tensor.scale[fp8_tensor], fp8_meta_tensor.scale,
out, out,
fp8_meta_tensor.amax_history[0][fp8_tensor], fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv[fp8_tensor], fp8_meta_tensor.scale_inv,
fp8_tensor,
otype otype
) )
return None return None
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "extensions.h" #include "extensions.h"
at::Tensor cast_to_fp8(const at::Tensor &input, at::Tensor cast_to_fp8(const at::Tensor &input,
const at::Tensor &scale, const at::Tensor &scale,
at::Tensor amax, at::Tensor amax,
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <torch/script.h> #include <torch/script.h>
#include "extensions.h" #include "extensions.h"
namespace { namespace {
transformer_engine::DType reverse_map_dtype(int64_t dtype) { transformer_engine::DType reverse_map_dtype(int64_t dtype) {
if (dtype >= 0 && dtype < static_cast<int64_t>(transformer_engine::DType::kNumTypes)) { if (dtype >= 0 && dtype < static_cast<int64_t>(transformer_engine::DType::kNumTypes)) {
...@@ -20,8 +21,8 @@ namespace { ...@@ -20,8 +21,8 @@ namespace {
at::Tensor cast_to_fp8_ts(const at::Tensor &input, at::Tensor cast_to_fp8_ts(const at::Tensor &input,
const at::Tensor &scale, const at::Tensor &scale,
const at::Tensor &amax, at::Tensor amax,
const at::Tensor &scale_inv, at::Tensor scale_inv,
int64_t fp8_tensor, int64_t fp8_tensor,
int64_t otype) { int64_t otype) {
transformer_engine::DType otype_arg = reverse_map_dtype(otype); transformer_engine::DType otype_arg = reverse_map_dtype(otype);
...@@ -33,6 +34,25 @@ at::Tensor cast_to_fp8_ts(const at::Tensor &input, ...@@ -33,6 +34,25 @@ at::Tensor cast_to_fp8_ts(const at::Tensor &input,
return output; return output;
} }
// TorchScript-visible wrapper around cast_to_fp8_noalloc: quantizes `input`
// into the caller-provided `output` buffer (no allocation), updating the
// amax/scale_inv bookkeeping slots selected by `fp8_tensor`.
//
// `scale`, `amax`, and `scale_inv` are the full FP8 meta tensors; the
// per-tensor entries (scale[i], amax[0][i], scale_inv[i]) are indexed here so
// the op schema only deals in whole tensors plus an integer index, which keeps
// it ONNX-export friendly. Returns `output` for TorchScript chaining.
at::Tensor cast_to_fp8_noalloc_ts(const at::Tensor &input,
                                  const at::Tensor &scale,
                                  at::Tensor output,
                                  at::Tensor amax,
                                  at::Tensor scale_inv,
                                  int64_t fp8_tensor,
                                  int64_t otype) {
  cast_to_fp8_noalloc(input,
                      scale[fp8_tensor],
                      output,
                      amax[0][fp8_tensor],
                      scale_inv[fp8_tensor],
                      // Map the int64 schema value onto the TE dtype enum.
                      reverse_map_dtype(otype));
  return output;
}
at::Tensor cast_from_fp8_ts(const at::Tensor &input, at::Tensor cast_from_fp8_ts(const at::Tensor &input,
const at::Tensor &scale_inv, const at::Tensor &scale_inv,
int64_t fp8_tensor, int64_t fp8_tensor,
...@@ -47,6 +67,7 @@ at::Tensor cast_from_fp8_ts(const at::Tensor &input, ...@@ -47,6 +67,7 @@ at::Tensor cast_from_fp8_ts(const at::Tensor &input,
return output; return output;
} }
at::Tensor gelu_ts(at::Tensor input, at::Tensor gelu_ts(at::Tensor input,
at::Tensor scale, at::Tensor scale,
at::Tensor amax, at::Tensor amax,
...@@ -82,6 +103,7 @@ at::Tensor gelu_ts(at::Tensor input, ...@@ -82,6 +103,7 @@ at::Tensor gelu_ts(at::Tensor input,
return output; return output;
} }
at::Tensor relu_ts(at::Tensor input, at::Tensor relu_ts(at::Tensor input,
at::Tensor scale, at::Tensor scale,
at::Tensor amax, at::Tensor amax,
...@@ -117,6 +139,7 @@ at::Tensor relu_ts(at::Tensor input, ...@@ -117,6 +139,7 @@ at::Tensor relu_ts(at::Tensor input,
return output; return output;
} }
at::Tensor reglu_ts(at::Tensor input, at::Tensor reglu_ts(at::Tensor input,
at::Tensor scale, at::Tensor scale,
at::Tensor amax, at::Tensor amax,
...@@ -152,6 +175,7 @@ at::Tensor reglu_ts(at::Tensor input, ...@@ -152,6 +175,7 @@ at::Tensor reglu_ts(at::Tensor input,
return output; return output;
} }
at::Tensor geglu_ts(at::Tensor input, at::Tensor geglu_ts(at::Tensor input,
at::Tensor scale, at::Tensor scale,
at::Tensor amax, at::Tensor amax,
...@@ -187,6 +211,7 @@ at::Tensor geglu_ts(at::Tensor input, ...@@ -187,6 +211,7 @@ at::Tensor geglu_ts(at::Tensor input,
return output; return output;
} }
at::Tensor swiglu_ts(at::Tensor input, at::Tensor swiglu_ts(at::Tensor input,
at::Tensor scale, at::Tensor scale,
at::Tensor amax, at::Tensor amax,
...@@ -222,6 +247,7 @@ at::Tensor swiglu_ts(at::Tensor input, ...@@ -222,6 +247,7 @@ at::Tensor swiglu_ts(at::Tensor input,
return output; return output;
} }
at::Tensor te_gemm_ts(at::Tensor A, at::Tensor te_gemm_ts(at::Tensor A,
at::Tensor A_scale_inverse, at::Tensor A_scale_inverse,
int64_t A_fp8_tensor, int64_t A_fp8_tensor,
...@@ -286,6 +312,7 @@ at::Tensor te_gemm_ts(at::Tensor A, ...@@ -286,6 +312,7 @@ at::Tensor te_gemm_ts(at::Tensor A,
return D; return D;
} }
at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input, at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input,
const at::Tensor &weight, const at::Tensor &weight,
const at::Tensor &bias, const at::Tensor &bias,
...@@ -312,6 +339,7 @@ at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input, ...@@ -312,6 +339,7 @@ at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input,
return output; return output;
} }
at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input, at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input,
const at::Tensor &weight, const at::Tensor &weight,
const at::Tensor &bias, const at::Tensor &bias,
...@@ -328,6 +356,7 @@ at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input, ...@@ -328,6 +356,7 @@ at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input,
return output; return output;
} }
at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input, at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input,
const at::Tensor &weight, const at::Tensor &weight,
double eps, double eps,
...@@ -352,6 +381,7 @@ at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input, ...@@ -352,6 +381,7 @@ at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input,
return output; return output;
} }
at::Tensor rmsnorm_fwd_inf_ts(const at::Tensor &input, at::Tensor rmsnorm_fwd_inf_ts(const at::Tensor &input,
const at::Tensor &weight, const at::Tensor &weight,
double eps, double eps,
...@@ -366,8 +396,10 @@ at::Tensor rmsnorm_fwd_inf_ts(const at::Tensor &input, ...@@ -366,8 +396,10 @@ at::Tensor rmsnorm_fwd_inf_ts(const at::Tensor &input,
return output; return output;
} }
TORCH_LIBRARY(tex_ts, m) { TORCH_LIBRARY(tex_ts, m) {
m.def("cast_to_fp8_ts", &cast_to_fp8_ts); m.def("cast_to_fp8_ts", &cast_to_fp8_ts);
m.def("cast_to_fp8_noalloc_ts", &cast_to_fp8_noalloc_ts);
m.def("cast_from_fp8_ts", &cast_from_fp8_ts); m.def("cast_from_fp8_ts", &cast_from_fp8_ts);
m.def("gelu_ts", &gelu_ts); m.def("gelu_ts", &gelu_ts);
m.def("relu_ts", &relu_ts); m.def("relu_ts", &relu_ts);
......
...@@ -36,6 +36,8 @@ from ..distributed import ( ...@@ -36,6 +36,8 @@ from ..distributed import (
allreduce, allreduce,
reduce_scatter_along_first_dim, reduce_scatter_along_first_dim,
gather_along_first_dim, gather_along_first_dim,
is_fp8_activation_recompute_enabled,
in_fp8_activation_recompute_phase,
) )
from ..constants import GemmParallelModes, dist_group_type, TE_DType from ..constants import GemmParallelModes, dist_group_type, TE_DType
from ..jit import no_torch_dynamo from ..jit import no_torch_dynamo
...@@ -173,7 +175,9 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -173,7 +175,9 @@ class _LayerNormLinear(torch.autograd.Function):
fp8_meta=fp8_meta, fp8_meta=fp8_meta,
fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
) )
if is_grad_enabled: if (is_grad_enabled
or (is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase())):
tex.fp8_cast_transpose_fused( tex.fp8_cast_transpose_fused(
weight, weight,
fp8_meta["scaling_fwd"], fp8_meta["scaling_fwd"],
...@@ -183,11 +187,12 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -183,11 +187,12 @@ class _LayerNormLinear(torch.autograd.Function):
transpose_out=weight_t_fp8._data, transpose_out=weight_t_fp8._data,
) )
else: else:
weight_fp8._data = tex.cast_to_fp8( tex.cast_to_fp8(
weight, weight,
fp8_meta["scaling_fwd"], fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT, tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward, fp8_dtype_forward,
out=weight_fp8._data,
) )
weight_t_fp8 = None weight_t_fp8 = None
......
...@@ -41,6 +41,8 @@ from ..distributed import ( ...@@ -41,6 +41,8 @@ from ..distributed import (
allreduce, allreduce,
reduce_scatter_along_first_dim, reduce_scatter_along_first_dim,
gather_along_first_dim, gather_along_first_dim,
is_fp8_activation_recompute_enabled,
in_fp8_activation_recompute_phase,
) )
from .. import cpp_extensions as tex from .. import cpp_extensions as tex
...@@ -219,7 +221,9 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -219,7 +221,9 @@ class _LayerNormMLP(torch.autograd.Function):
fp8_meta=fp8_meta, fp8_meta=fp8_meta,
fp8_meta_index=tex.FP8FwdTensors.GEMM2_WEIGHT, fp8_meta_index=tex.FP8FwdTensors.GEMM2_WEIGHT,
) )
if is_grad_enabled: if (is_grad_enabled
or (is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase())):
# Fused cast-transpose kernels # Fused cast-transpose kernels
tex.fp8_cast_transpose_fused( tex.fp8_cast_transpose_fused(
fc1_weight, fc1_weight,
...@@ -238,18 +242,20 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -238,18 +242,20 @@ class _LayerNormMLP(torch.autograd.Function):
transpose_out=fc2_weight_t_fp8._data, transpose_out=fc2_weight_t_fp8._data,
) )
else: else:
fc1_weight_fp8._data = tex.cast_to_fp8( tex.cast_to_fp8(
fc1_weight, fc1_weight,
fp8_meta["scaling_fwd"], fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT, tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward, fp8_dtype_forward,
out=fc1_weight_fp8._data,
) )
fc1_weight_t_fp8 = None fc1_weight_t_fp8 = None
fc2_weight_fp8._data = tex.cast_to_fp8( tex.cast_to_fp8(
fc2_weight, fc2_weight,
fp8_meta["scaling_fwd"], fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM2_WEIGHT, tex.FP8FwdTensors.GEMM2_WEIGHT,
fp8_dtype_forward, fp8_dtype_forward,
out=fc2_weight_fp8._data,
) )
fc2_weight_t_fp8 = None fc2_weight_t_fp8 = None
......
...@@ -34,6 +34,8 @@ from ..distributed import ( ...@@ -34,6 +34,8 @@ from ..distributed import (
allreduce, allreduce,
reduce_scatter_along_first_dim, reduce_scatter_along_first_dim,
gather_along_first_dim, gather_along_first_dim,
is_fp8_activation_recompute_enabled,
in_fp8_activation_recompute_phase,
) )
from ..cpp_extensions import ( from ..cpp_extensions import (
fp8_gemm, fp8_gemm,
...@@ -155,7 +157,9 @@ class _Linear(torch.autograd.Function): ...@@ -155,7 +157,9 @@ class _Linear(torch.autograd.Function):
fp8_meta=fp8_meta, fp8_meta=fp8_meta,
fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
) )
if is_grad_enabled: if (is_grad_enabled
or (is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase())):
fp8_cast_transpose_fused( fp8_cast_transpose_fused(
weight, weight,
fp8_meta["scaling_fwd"], fp8_meta["scaling_fwd"],
...@@ -165,11 +169,12 @@ class _Linear(torch.autograd.Function): ...@@ -165,11 +169,12 @@ class _Linear(torch.autograd.Function):
transpose_out=weight_t_fp8._data, transpose_out=weight_t_fp8._data,
) )
else: else:
weight_fp8._data = cast_to_fp8( cast_to_fp8(
weight, weight,
fp8_meta["scaling_fwd"], fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT, tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward, fp8_dtype_forward,
out=weight_fp8._data,
) )
weight_t_fp8 = None weight_t_fp8 = None
......
...@@ -130,6 +130,13 @@ def onnx_cast_to_fp8(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): ...@@ -130,6 +130,13 @@ def onnx_cast_to_fp8(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
return quantize(g, inputs, scale_inv, fp8_tensor) return quantize(g, inputs, scale_inv, fp8_tensor)
@symbolic_helper.parse_args("v", "v", "v", "v", "fs", "i", "i")
def onnx_cast_to_fp8_noalloc(g, inputs, scale, output, amax, scale_inv, fp8_tensor, otype):
    """ONNX symbolic for the ``cast_to_fp8_noalloc`` op.

    Exported as a plain quantize node: only ``inputs``, ``scale_inv`` and
    ``fp8_tensor`` matter for the ONNX graph; the preallocated ``output``
    buffer and amax/scale bookkeeping are runtime-only concerns and are
    deliberately ignored here.
    """
    # pylint: disable=unused-argument
    quantized = quantize(g, inputs, scale_inv, fp8_tensor)
    return quantized
@symbolic_helper.parse_args("v", "fs", "i", "i", "i") @symbolic_helper.parse_args("v", "fs", "i", "i", "i")
def onnx_cast_from_fp8(g, inputs, scale_inv, fp8_tensor, itype, otype): def onnx_cast_from_fp8(g, inputs, scale_inv, fp8_tensor, itype, otype):
"""ONNX graph for cast_from_fp8""" """ONNX graph for cast_from_fp8"""
...@@ -393,10 +400,11 @@ def onnx_rmsnorm_fwd(g, inputs, weight, eps, zero_centered_gamma): ...@@ -393,10 +400,11 @@ def onnx_rmsnorm_fwd(g, inputs, weight, eps, zero_centered_gamma):
result = g.op("Mul", weight, normalized_input) result = g.op("Mul", weight, normalized_input)
result = g.op("Cast", result, to_i=get_TensorProtoDataType(inputs)) result = g.op("Cast", result, to_i=get_TensorProtoDataType(inputs))
return result return result
register_custom_op_symbolic('tex_ts::cast_to_fp8_ts', onnx_cast_to_fp8, VER) register_custom_op_symbolic('tex_ts::cast_to_fp8_ts', onnx_cast_to_fp8, VER)
register_custom_op_symbolic('tex_ts::cast_to_fp8_noalloc_ts', onnx_cast_to_fp8_noalloc, VER)
register_custom_op_symbolic('tex_ts::cast_from_fp8_ts', onnx_cast_from_fp8, VER) register_custom_op_symbolic('tex_ts::cast_from_fp8_ts', onnx_cast_from_fp8, VER)
register_custom_op_symbolic('tex_ts::gelu_ts', onnx_fp8_gelu, VER) register_custom_op_symbolic('tex_ts::gelu_ts', onnx_fp8_gelu, VER)
register_custom_op_symbolic('tex_ts::relu_ts', onnx_fp8_relu, VER) register_custom_op_symbolic('tex_ts::relu_ts', onnx_fp8_relu, VER)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment