Commit fbee8990 authored by yuguo's avatar yuguo
Browse files

[DCU] fix fp8

parent 57deee08
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# #
# See LICENSE for license information. # See LICENSE for license information.
# CXX=hipcc make build && cd build && cmake ../
cmake_minimum_required(VERSION 3.18) cmake_minimum_required(VERSION 3.18)
option(USE_CUDA "Use CUDA" ON) option(USE_CUDA "Use CUDA" ON)
......
...@@ -58,7 +58,7 @@ void compute_amax_scale_ref(const InputType *data, ...@@ -58,7 +58,7 @@ void compute_amax_scale_ref(const InputType *data,
float scale = 1.f; float scale = 1.f;
float scale_inv = 1.f; float scale_inv = 1.f;
if (isinf(clamp_amax) || clamp_amax == 0.f) { if (std::isinf(clamp_amax) || clamp_amax == 0.f) {
*scale_ptr = scale; *scale_ptr = scale;
*scale_inv_ptr = scale_inv; *scale_inv_ptr = scale_inv;
return; return;
...@@ -69,11 +69,11 @@ void compute_amax_scale_ref(const InputType *data, ...@@ -69,11 +69,11 @@ void compute_amax_scale_ref(const InputType *data,
// The amax is too small that the scale becoming infinite in FP32. In other word, // The amax is too small that the scale becoming infinite in FP32. In other word,
// the scale is not representable in FP32. // the scale is not representable in FP32.
if (isinf(scale)) { if (std::isinf(scale)) {
scale = std::numeric_limits<float>::max(); scale = std::numeric_limits<float>::max();
} }
if (isnan(scale)) { if (std::isnan(scale)) {
scale = 1.f; scale = 1.f;
} }
......
...@@ -69,7 +69,7 @@ void scale_block(const ProcessingMethod processing_method, ...@@ -69,7 +69,7 @@ void scale_block(const ProcessingMethod processing_method,
elt *= static_cast<float>(grad[idx]); elt *= static_cast<float>(grad[idx]);
} }
dbias[j] += elt; dbias[j] += elt;
if (isinf(elt) || isnan(elt)) { if (std::isinf(elt) || std::isnan(elt)) {
continue; continue;
} }
amax = std::max(amax, std::abs(elt)); amax = std::max(amax, std::abs(elt));
......
...@@ -62,7 +62,7 @@ void compute_amax_scale_ref(const InputType *data, ...@@ -62,7 +62,7 @@ void compute_amax_scale_ref(const InputType *data,
float scale = 1.f; float scale = 1.f;
float scale_inv = 1.f; float scale_inv = 1.f;
if (isinf(clamp_amax) || clamp_amax == 0.f) { if (std::isinf(clamp_amax) || clamp_amax == 0.f) {
*scale_ptr = scale; *scale_ptr = scale;
*scale_inv_ptr = scale_inv; *scale_inv_ptr = scale_inv;
return; return;
...@@ -73,11 +73,11 @@ void compute_amax_scale_ref(const InputType *data, ...@@ -73,11 +73,11 @@ void compute_amax_scale_ref(const InputType *data,
// The amax is too small that the scale becoming infinite in FP32. In other word, // The amax is too small that the scale becoming infinite in FP32. In other word,
// the scale is not representable in FP32. // the scale is not representable in FP32.
if (isinf(scale)) { if (std::isinf(scale)) {
scale = std::numeric_limits<float>::max(); scale = std::numeric_limits<float>::max();
} }
if (isnan(scale)) { if (std::isnan(scale)) {
scale = 1.f; scale = 1.f;
} }
......
...@@ -111,16 +111,16 @@ void performTest(bool use_bias, bool use_gelu, const size_t m, const size_t k, c ...@@ -111,16 +111,16 @@ void performTest(bool use_bias, bool use_gelu, const size_t m, const size_t k, c
DType dtype = TypeInfo<D_Type>::dtype; DType dtype = TypeInfo<D_Type>::dtype;
// pytorch tensor storage is row-major while cublas/rocblas is column-major // pytorch tensor storage is row-major while cublas/rocblas is column-major
Tensor A({ k, m }, atype); Tensor A("A", { k, m }, atype);
Tensor B({ n, k }, btype); Tensor B("B", { n, k }, btype);
Tensor D({ n, m }, dtype); Tensor D("D", { n, m }, dtype);
Tensor bias; Tensor bias;
if(use_bias){ if(use_bias){
bias = Tensor({m}, bias_type); bias = Tensor("bias", {m}, bias_type);
} }
Tensor pre_gelu_out; Tensor pre_gelu_out;
if(use_gelu){ if(use_gelu){
pre_gelu_out = Tensor({ n, m }, gelu_type); pre_gelu_out = Tensor("pre_gelu_out", { n, m }, gelu_type);
} }
//initialize the data and scale inv of A, B //initialize the data and scale inv of A, B
...@@ -149,7 +149,7 @@ void performTest(bool use_bias, bool use_gelu, const size_t m, const size_t k, c ...@@ -149,7 +149,7 @@ void performTest(bool use_bias, bool use_gelu, const size_t m, const size_t k, c
} }
#endif #endif
Tensor Workspace({ 33554432 }, DType::kByte); Tensor Workspace("Workspace", { 33554432 }, DType::kByte);
//perform the gemm in GPU //perform the gemm in GPU
nvte_cublas_gemm(A.data(), nvte_cublas_gemm(A.data(),
...@@ -180,11 +180,11 @@ void performTest(bool use_bias, bool use_gelu, const size_t m, const size_t k, c ...@@ -180,11 +180,11 @@ void performTest(bool use_bias, bool use_gelu, const size_t m, const size_t k, c
} }
float ref_amax_d; float ref_amax_d;
compute_ref<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>( compute_ref<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>(
A.cpu_dptr<A_Type>(), A.rowwise_cpu_dptr<A_Type>(),
B.cpu_dptr<B_Type>(), B.rowwise_cpu_dptr<B_Type>(),
A.scale_inv(), A.rowwise_scale_inv(),
B.scale_inv(), B.rowwise_scale_inv(),
use_bias? bias.cpu_dptr<Bias_Type>(): nullptr, use_bias? bias.rowwise_cpu_dptr<Bias_Type>(): nullptr,
D.scale(), D.scale(),
m, k, n, m, k, n,
ref_D.get(), ref_D.get(),
......
...@@ -143,7 +143,7 @@ void generate_data(InputType * const data, ...@@ -143,7 +143,7 @@ void generate_data(InputType * const data,
if (is_negative) { if (is_negative) {
val = -val; val = -val;
} }
data[idx] = static_cast<InputType>(val); data[idx] = static_cast<InputType>(static_cast<float>(val));
} }
} }
} }
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include <random> #include <random>
#include <cuda_bf16.h> #include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
...@@ -78,11 +79,17 @@ inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const ...@@ -78,11 +79,17 @@ inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const
} else { } else {
if (use_cudnn){ if (use_cudnn){
compute_t g = static_cast<compute_t>(0.f); compute_t g = static_cast<compute_t>(0.f);
#ifndef __HIP_PLATFORM_AMD__
InputType gi = gamma; InputType gi = gamma;
if (zero_centered_gamma) { if (zero_centered_gamma) {
gi = gi + static_cast<InputType>(1.f); gi = gi + static_cast<InputType>(1.f);
} }
g = static_cast<compute_t>(gi); g = static_cast<compute_t>(gi);
#else
if (zero_centered_gamma) {
g += static_cast<compute_t>(1.f);
}
#endif
return g; return g;
} else { } else {
compute_t g = static_cast<compute_t>(gamma); compute_t g = static_cast<compute_t>(gamma);
......
...@@ -133,7 +133,11 @@ void compute_ref_stats(NormType norm_type, ...@@ -133,7 +133,11 @@ void compute_ref_stats(NormType norm_type,
compute_t current = static_cast<compute_t>(data[i * H + j]); compute_t current = static_cast<compute_t>(data[i * H + j]);
sum_sq += (current - m) * (current - m); sum_sq += (current - m) * (current - m);
} }
#ifdef __HIP_PLATFORM_AMD__
rsigma[i] = 1.0/sqrtf((sum_sq / H) + epsilon);
#else
rsigma[i] = rsqrtf((sum_sq / H) + epsilon); rsigma[i] = rsqrtf((sum_sq / H) + epsilon);
#endif
} }
} }
......
...@@ -584,8 +584,13 @@ void compareResults_parallel(const std::string &name, const Tensor &test, const ...@@ -584,8 +584,13 @@ void compareResults_parallel(const std::string &name, const Tensor &test, const
const size_t i = getFirstMismatchIdx<T>(test.dtype(), test_data, ref_data, N, atol, rtol); const size_t i = getFirstMismatchIdx<T>(test.dtype(), test_data, ref_data, N, atol, rtol);
if (i != N) { if (i != N) {
#ifndef __HIP_PLATFORM_AMD__
const double t = static_cast<double>(test_data[i]); const double t = static_cast<double>(test_data[i]);
const double r = static_cast<double>(ref_data[i]); const double r = static_cast<double>(ref_data[i]);
#else
const double t = static_cast<double>(static_cast<float>(test_data[i]));
const double r = static_cast<double>(static_cast<float>(ref_data[i]));
#endif
std::string direction = rowwise ? "rowwise" : "columnwise"; std::string direction = rowwise ? "rowwise" : "columnwise";
ASSERT_FALSE(true) << "Error in tensor " << name << " in " ASSERT_FALSE(true) << "Error in tensor " << name << " in "
<< direction << " direction." << std::endl << direction << " direction." << std::endl
...@@ -607,8 +612,13 @@ void compareResults(const std::string &name, const Tensor &test, const void *ref ...@@ -607,8 +612,13 @@ void compareResults(const std::string &name, const Tensor &test, const void *ref
void compareResults(const std::string &name, const float test, const float ref, void compareResults(const std::string &name, const float test, const float ref,
double atol, double rtol) { double atol, double rtol) {
#ifndef __HIP_PLATFORM_AMD__
double t = static_cast<double>(test); double t = static_cast<double>(test);
double r = static_cast<double>(ref); double r = static_cast<double>(ref);
#else
double t = static_cast<double>(static_cast<float>(test));
double r = static_cast<double>(static_cast<float>(ref));
#endif
bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol); bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
ASSERT_FALSE(mismatch) << "Error in " << name << std::endl ASSERT_FALSE(mismatch) << "Error in " << name << std::endl
<< "Mismatch: " << t << " vs " << r; << "Mismatch: " << t << " vs " << r;
...@@ -692,7 +702,11 @@ void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) { ...@@ -692,7 +702,11 @@ void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) {
std::uniform_real_distribution<> dis(-2.0, 1.0); std::uniform_real_distribution<> dis(-2.0, 1.0);
for (int i = idx_min; i < idx_max; ++i) { for (int i = idx_min; i < idx_max; ++i) {
#ifndef __HIP_PLATFORM_AMD__
data[i] = static_cast<T>(dis(gen_local)); data[i] = static_cast<T>(dis(gen_local));
#else
data[i] = static_cast<T>(static_cast<float>(dis(gen_local)));
#endif
} }
} }
gen->discard(size); gen->discard(size);
......
...@@ -61,6 +61,7 @@ using bf16 = nv_bfloat16; ...@@ -61,6 +61,7 @@ using bf16 = nv_bfloat16;
using fp8e4m3 = __nv_fp8_e4m3; using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2; using fp8e5m2 = __nv_fp8_e5m2;
#else #else
using bf16 = __hip_bfloat16;
using fp8e4m3 = te_hip_fp8_e4m3; using fp8e4m3 = te_hip_fp8_e4m3;
using fp8e5m2 = te_hip_fp8_e5m2; using fp8e5m2 = te_hip_fp8_e5m2;
#endif //USE_ROCM #endif //USE_ROCM
...@@ -325,7 +326,11 @@ struct Numeric_Traits<fp8e4m3> { ...@@ -325,7 +326,11 @@ struct Numeric_Traits<fp8e4m3> {
static constexpr double minSubnorm = 1.0 / static_cast<double>(1 << 9); // std::pow(2.0, -9.0); static constexpr double minSubnorm = 1.0 / static_cast<double>(1 << 9); // std::pow(2.0, -9.0);
static constexpr double maxSubnorm = 0.875 / static_cast<double>(1 << 6); // std::pow(2.0, -6.0); static constexpr double maxSubnorm = 0.875 / static_cast<double>(1 << 6); // std::pow(2.0, -6.0);
static constexpr double minNorm = 1.0 / static_cast<double>(1 << 6); // std::pow(2.0, -6.0); static constexpr double minNorm = 1.0 / static_cast<double>(1 << 6); // std::pow(2.0, -6.0);
#ifndef __HIP_PLATFORM_AMD__
static constexpr double maxNorm = 448.0; static constexpr double maxNorm = 448.0;
#else
static constexpr double maxNorm = 240.0;
#endif
static constexpr double artifInf = 10.0 * maxNorm; // artificial Infinity static constexpr double artifInf = 10.0 * maxNorm; // artificial Infinity
static constexpr int maxBiasedExponentAsFP32 = 8 + FP32_EXPONENT_BIAS; static constexpr int maxBiasedExponentAsFP32 = 8 + FP32_EXPONENT_BIAS;
static constexpr int maxUnbiasedExponentAsFP32 = 8; static constexpr int maxUnbiasedExponentAsFP32 = 8;
......
...@@ -10,7 +10,7 @@ import transformer_engine_torch as tex ...@@ -10,7 +10,7 @@ import transformer_engine_torch as tex
from transformer_engine.pytorch.optimizers import MultiTensorApply from transformer_engine.pytorch.optimizers import MultiTensorApply
from references.ref_per_tensor_cs import ref_compute_scale_and_scale_inv_from_amax from references.ref_per_tensor_cs import ref_compute_scale_and_scale_inv_from_amax
from torch.utils.cpp_extension import IS_HIP_EXTENSION
input_size_pairs = [ input_size_pairs = [
(7777 * 77, 555 * 555), (7777 * 77, 555 * 555),
...@@ -224,7 +224,7 @@ def test_multi_tensor_unscale_l2norm(input_size_pair, applier, repeat, in_type, ...@@ -224,7 +224,7 @@ def test_multi_tensor_unscale_l2norm(input_size_pair, applier, repeat, in_type,
@pytest.mark.parametrize("input_size_pair", input_size_pairs + [(1, 1)]) @pytest.mark.parametrize("input_size_pair", input_size_pairs + [(1, 1)])
@pytest.mark.parametrize("applier", appliers) @pytest.mark.parametrize("applier", appliers)
@pytest.mark.parametrize("repeat", [1, 55]) @pytest.mark.parametrize("repeat", [1, 55])
@pytest.mark.parametrize("max_fp8", [448.0, 57344.0]) @pytest.mark.parametrize("max_fp8", [448.0 if not IS_HIP_EXTENSION else 240.0, 57344.0])
@pytest.mark.parametrize("pow_2_scales", [False, True]) @pytest.mark.parametrize("pow_2_scales", [False, True])
@pytest.mark.parametrize("epsilon", [0.0, 100.0]) @pytest.mark.parametrize("epsilon", [0.0, 100.0])
def test_multi_tensor_compute_scale_and_scale_inv( def test_multi_tensor_compute_scale_and_scale_inv(
......
...@@ -165,6 +165,7 @@ else() ...@@ -165,6 +165,7 @@ else()
activation/relu.cu activation/relu.cu
activation/swiglu.cu activation/swiglu.cu
gemm/cublaslt_gemm.cu gemm/cublaslt_gemm.cu
gemm/hipblas_gemm.cu
normalization/common.cpp normalization/common.cpp
normalization/layernorm/ln_api.cpp normalization/layernorm/ln_api.cpp
normalization/layernorm/ln_bwd_semi_cuda_kernel.cu normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
......
...@@ -8,7 +8,7 @@ import warnings ...@@ -8,7 +8,7 @@ import warnings
from enum import Enum from enum import Enum
from typing import Literal, Optional, Union, Callable, NamedTuple from typing import Literal, Optional, Union, Callable, NamedTuple
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
from torch.utils.cpp_extension import IS_HIP_EXTENSION
class _FormatHelper(NamedTuple): class _FormatHelper(NamedTuple):
""" """
...@@ -34,7 +34,7 @@ class Format(Enum): ...@@ -34,7 +34,7 @@ class Format(Enum):
FP8 tensors in the backward pass are in e5m2 format FP8 tensors in the backward pass are in e5m2 format
""" """
E4M3 = _FormatHelper(max_fwd=448, max_bwd=448) E4M3 = _FormatHelper(max_fwd=448 if not IS_HIP_EXTENSION else 240.0, max_bwd=448 if not IS_HIP_EXTENSION else 240.0)
E5M2 = _FormatHelper(max_fwd=57344, max_bwd=57344) E5M2 = _FormatHelper(max_fwd=57344, max_bwd=57344)
HYBRID = _FormatHelper(max_fwd=E4M3.max_fwd, max_bwd=E5M2.max_bwd) HYBRID = _FormatHelper(max_fwd=E4M3.max_fwd, max_bwd=E5M2.max_bwd)
......
...@@ -36,7 +36,11 @@ const char* dtype_name(DType dtype) { ...@@ -36,7 +36,11 @@ const char* dtype_name(DType dtype) {
inline float fp8_dtype_max(DType dtype) { inline float fp8_dtype_max(DType dtype) {
switch (dtype) { switch (dtype) {
case DType::kFloat8E4M3: case DType::kFloat8E4M3:
#ifndef __HIP_PLATFORM_AMD__
return 448; return 448;
#else
return 240;
#endif
case DType::kFloat8E5M2: case DType::kFloat8E5M2:
return 57344; return 57344;
default: default:
......
...@@ -1002,7 +1002,11 @@ struct Numeric_Traits; ...@@ -1002,7 +1002,11 @@ struct Numeric_Traits;
template <> template <>
struct Numeric_Traits<fp8e4m3> { struct Numeric_Traits<fp8e4m3> {
static constexpr int maxUnbiasedExponent = 8; static constexpr int maxUnbiasedExponent = 8;
#ifndef __HIP_PLATFORM_AMD__
static constexpr double maxNorm = 448; static constexpr double maxNorm = 448;
#else
static constexpr double maxNorm = 240;
#endif
}; };
template <> template <>
......
...@@ -14,7 +14,7 @@ import torch ...@@ -14,7 +14,7 @@ import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
from .multi_tensor_apply import multi_tensor_applier from .multi_tensor_apply import multi_tensor_applier
from torch.utils.cpp_extension import IS_HIP_EXTENSION
def get_fp8_meta(fp8_tensor): def get_fp8_meta(fp8_tensor):
"""FP8 metadata getter.""" """FP8 metadata getter."""
...@@ -197,7 +197,7 @@ class FusedAdam(torch.optim.Optimizer): ...@@ -197,7 +197,7 @@ class FusedAdam(torch.optim.Optimizer):
torch.float16: torch.full( torch.float16: torch.full(
[1], torch.finfo(torch.float16).max / 2.0, dtype=torch.float32 [1], torch.finfo(torch.float16).max / 2.0, dtype=torch.float32
), ),
torch.uint8: torch.full([1], 448.0, dtype=torch.float32), torch.uint8: torch.full([1], 448.0 if not IS_HIP_EXTENSION else 240.0, dtype=torch.float32),
} }
self._scales = {} self._scales = {}
self.use_decoupled_grad = use_decoupled_grad self.use_decoupled_grad = use_decoupled_grad
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
"""Helper functions for using fp8 tensors as weights""" """Helper functions for using fp8 tensors as weights"""
import torch import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine_torch import multi_tensor_scale, multi_tensor_compute_scale_and_scale_inv from transformer_engine_torch import multi_tensor_scale, multi_tensor_compute_scale_and_scale_inv
...@@ -243,7 +244,7 @@ def _cast_master_weights_to_fp8_current_scaling(params, group): ...@@ -243,7 +244,7 @@ def _cast_master_weights_to_fp8_current_scaling(params, group):
# Step 3: Update scales and scale_invs. # Step 3: Update scales and scale_invs.
# --------------------------------------------------------------------------------------------- # ---------------------------------------------------------------------------------------------
if fp8_dtype == tex.DType.kFloat8E4M3: if fp8_dtype == tex.DType.kFloat8E4M3:
max_fp8 = 448.0 max_fp8 = 448.0 if not IS_HIP_EXTENSION else 240.0
elif fp8_dtype == tex.DType.kFloat8E5M2: elif fp8_dtype == tex.DType.kFloat8E5M2:
max_fp8 = 57344.0 max_fp8 = 57344.0
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment