Commit 9666d263 authored by wenjh

[DCU] Use OCP FP8 (same as NVIDIA)



Use OCP FP8 (the same FP8 encodings NVIDIA uses) instead of the AMD-specific FNUZ variants.
Workaround: test_cast_float8blockwise.cu links the wrong std::max, so the HIP path replaces std::max calls with ternary expressions.
Signed-off-by: wenjh <wenjh@sugon.com>
parent 80c5079c
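
For context, the numeric effect of this switch is on the E4M3 dynamic range: OCP FP8 E4M3 (the format NVIDIA uses, now adopted for the DCU/HIP build) saturates at 448.0, while the previously used FNUZ E4M3 saturates at 240.0, which is why the 240.0 special cases disappear throughout this diff. A minimal illustrative sketch follows; the constants come from the diff itself, while the helper function is hypothetical and not part of the repository:

```cpp
// Illustrative only: FP8 E4M3 maximums before and after this change.
constexpr float kOcpE4M3Max  = 448.0f;  // OCP E4M3 max, now used on both CUDA and HIP builds
constexpr float kFnuzE4M3Max = 240.0f;  // FNUZ E4M3 max, removed from the HIP-specific paths

// Hypothetical helper mirroring how a quantization scale is derived from an amax
// (compare scales_from_amax() and fp8_dtype_max() in the hunks below).
inline float e4m3_scale_from_amax(float amax, bool ocp = true) {
  const float max_val = ocp ? kOcpE4M3Max : kFnuzE4M3Max;
  return max_val / amax;
}
```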
@@ -4,6 +4,7 @@
     "../util/cuda_runtime.h": "../util/hip_runtime.h",
     "common/util/cuda_driver.h": "common/util/hip_driver.h",
     "../util/cuda_driver.h": "../util/hip_driver.h",
+    "./util/cuda_driver.h": "./util/hip_driver.h",
     "common/util/cuda_nvml.h": "common/util/hip_nvml.h",
     "common/utils.cuh" : "common/utils_hip.cuh",
     "common/transpose/cast_transpose.h" : "common/transpose/cast_transpose_hip.h",
@@ -15,14 +16,17 @@
     "/logging.h" : "/logging_hip.h",
     "/system.h" : "/system_hip.h",
     "<cuda_bf16.h>" : "<hip/hip_bf16.h>",
-    "<cuda_fp8.h>" : "\"amd_detail/hip_float8.h\"",
+    "<cuda_fp8.h>" : "<hip/hip_fp8.h>",
     "CUfunc_cache" : "hipFuncCache_t",
     "<nvtx3/nvToolsExt.h>" : "<roctracer/roctx.h>",
     "cudaLaunchKernelExC" : "hipLaunchKernelExC",
     "cudaLaunchConfig_t" : "hipLaunchConfig_t",
     "cudaLaunchAttributeClusterDimension" : "hipLaunchAttributeClusterDimension",
     "cudaLaunchAttributeCooperative" : "hipLaunchAttributeCooperative",
-    "cudaLaunchAttribute": "hipLaunchAttribute"
+    "cudaLaunchAttribute": "hipLaunchAttribute",
+    "__nv_fp8_e4m3": "__hip_fp8_e4m3",
+    "__nv_fp8_e5m2": "__hip_fp8_e5m2",
+    "nv_bfloat16": "__hip_bfloat16"
   }
 }
@@ -50,7 +50,11 @@ void scales_from_amax(float amax, const QuantizationOptions& opts, float* qscale
   float input_type_max_val = Quantized_Limits<InputType>::max();
   float quant_type_max_val = Quantized_Limits<OutputType>::max();
   float eps = opts.amax_epsilon;
+#ifdef __HIP_PLATFORM_AMD__
+  amax = amax > eps? amax: eps;
+#else
   amax = std::max(amax, eps);
+#endif
   float qscale = quant_type_max_val / amax;
   if (std::isinf(qscale)) {
     qscale = input_type_max_val;
@@ -101,7 +105,11 @@ void ref_quantize(const ProcessingMethod processing_method, const InputType* inp
         continue;
       }
       float val = static_cast<float>(input[y_pos * width + x_pos]);
+#ifdef __HIP_PLATFORM_AMD__
+      amax = amax > std::abs(val)? amax: std::abs(val);
+#else
       amax = std::max(amax, std::abs(val));
+#endif
     }
   }
@@ -172,7 +180,11 @@ void ref_quantize_onedimensional_blocks(const ProcessingMethod processing_method
         continue;
       }
       float val = static_cast<float>(input[y * width + x_pos]);
+#ifdef __HIP_PLATFORM_AMD__
+      amax = amax > std::abs(val)? amax: std::abs(val);
+#else
       amax = std::max(amax, std::abs(val));
+#endif
     }
     // We've calculated amax for a tile. Calculate scale and
@@ -204,7 +216,11 @@ void ref_quantize_onedimensional_blocks(const ProcessingMethod processing_method
         continue;
       }
       float val = static_cast<float>(input[x + y_pos * width]);
+#ifdef __HIP_PLATFORM_AMD__
+      amax = amax > std::abs(val)? amax: std::abs(val);
+#else
       amax = std::max(amax, std::abs(val));
+#endif
     }
     // We've calculated amax for a tile. Calculate scale and
...
@@ -12,14 +12,8 @@
 #include <random>
 #include <cuda_runtime_api.h>
-#ifndef USE_ROCM
 #include <cuda_bf16.h>
 #include <cuda_fp8.h>
-#else
-#include <hip/hip_bf16.h>
-#include "amd_detail/hip_float8.h"
-#endif
 #include <cuda_fp16.h>
 #include <transformer_engine/transformer_engine.h>
@@ -56,15 +50,9 @@ using int32 = int32_t;
 using int64 = int64_t;
 using fp32 = float;
 using fp16 = half;
-#ifndef USE_ROCM
 using bf16 = nv_bfloat16;
 using fp8e4m3 = __nv_fp8_e4m3;
 using fp8e5m2 = __nv_fp8_e5m2;
-#else
-using bf16 = __hip_bfloat16;
-using fp8e4m3 = te_hip_fp8_e4m3;
-using fp8e5m2 = te_hip_fp8_e5m2;
-#endif //USE_ROCM
 using fp8e8m0 = uint8_t;
 template <typename T>
@@ -324,11 +312,7 @@ struct Numeric_Traits<fp8e4m3> {
   static constexpr double minSubnorm = 1.0 / static_cast<double>(1 << 9);    // std::pow(2.0, -9.0);
   static constexpr double maxSubnorm = 0.875 / static_cast<double>(1 << 6);  // std::pow(2.0, -6.0);
   static constexpr double minNorm = 1.0 / static_cast<double>(1 << 6);       // std::pow(2.0, -6.0);
-#ifndef __HIP_PLATFORM_AMD__
   static constexpr double maxNorm = 448.0;
-#else
-  static constexpr double maxNorm = 240.0;
-#endif
   static constexpr double artifInf = 10.0 * maxNorm;  // artificial Infinity
   static constexpr int maxBiasedExponentAsFP32 = 8 + FP32_EXPONENT_BIAS;
   static constexpr int maxUnbiasedExponentAsFP32 = 8;
...
@@ -687,7 +687,7 @@ def _test_fp8_scale_update(
     """Expected absmax and FP8 scale"""
     amax = ref.abs().amax()
     max_val = {
-        "forward": 448.0 if not IS_HIP_EXTENSION else 240.0,
+        "forward": 448.0,
         "backward": 57344.0,
     }[stage]
     scale = (max_val / amax) / (2**margin)
...
@@ -258,7 +258,7 @@ class TestFP8Recipe:
         # Compute scale
         max_val = {
-            "forward": 448.0 if not IS_HIP_EXTENSION else 240.0,
+            "forward": 448.0,
             "backward": 57344.0,
         }[stage]
         ref_scale = (max_val / ref_amax) / (2**margin)
...
@@ -329,10 +329,6 @@ else()
                                string_code_transpose_rtc_cast_transpose_cu)
   make_string_header_from_file(transpose/rtc/transpose.hip
                                string_code_transpose_rtc_transpose_cu)
-  make_string_header_from_file(amd_detail/hip_float8.h
-                               string_code_amd_detail_hip_float8_h)
-  make_string_header_from_file(amd_detail/hip_f8_impl.h
-                               string_code_amd_detail_hip_f8_impl_h)
 endif()
...
@@ -11,7 +11,6 @@
 #include <cassert>
 #include <numeric>
-#include "amd_detail/hip_float8.h"
 #include "common/common.h"
 #include "common/util/cuda_driver.h"
 #include "common/util/cuda_runtime.h"
@@ -289,11 +288,7 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te
     assert(rs_output.size(0) == _ubuf.size(0) / _tp_size);
     assert(rs_output.element_size() == 2);
     char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
-#ifdef USE_ROCM
-    reducescatter2_userbuff_fp8<te_hip_fp8_e5m2>(rs_output_ptr, _ubuf.scale_inv(), _ub_reg, 0,
-#else
     reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf.scale_inv(), _ub_reg, 0,
-#endif
                                                comm_elements, _ub_comm, _stream_comm,
                                                (cudaEvent_t)_comm_launch_event);
   } else {
...
@@ -2033,53 +2033,6 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const
   }
 }
-#ifdef __HIP_PLATFORM_AMD__
-template void reducescatter2_userbuff_stridedoutput_fp8<te_hip_fp8_e5m2>(
-    void *output, float *scale, const int handler, const int offset, const int rowelements,
-    const int colelements, const int strideelements, communicator *comm, cudaStream_t stream,
-    cudaEvent_t comm_launch_event);
-template void reducescatter2_userbuff_stridedoutput_fp8<te_hip_fp8_e4m3>(
-    void *output, float *scale, const int handler, const int offset, const int rowelements,
-    const int colelements, const int strideelements, communicator *comm, cudaStream_t stream,
-    cudaEvent_t comm_launch_event);
-template <typename fp8type>
-void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, const int offset,
-                                 const int elements, communicator *comm, cudaStream_t stream,
-                                 cudaEvent_t comm_launch_event) {
-  reducescatter2_userbuff_stridedoutput_fp8<fp8type>(output, scale, handler, offset, elements, 1, 0,
-                                                     comm, stream, comm_launch_event);
-}
-template void reducescatter2_userbuff_fp8<te_hip_fp8_e5m2>(void *output, float *scale,
-                                                            const int handler, const int offset,
-                                                            const int elements, communicator *comm,
-                                                            cudaStream_t stream,
-                                                            cudaEvent_t comm_launch_event);
-template void reducescatter2_userbuff_fp8<te_hip_fp8_e4m3>(void *output, float *scale,
-                                                            const int handler, const int offset,
-                                                            const int elements, communicator *comm,
-                                                            cudaStream_t stream,
-                                                            cudaEvent_t comm_launch_event);
-template void reducescatter2_userbuff_strided_atomic_fp8<te_hip_fp8_e4m3>(
-    void *output, float *scale, const int handler, const int offset, const int rowelements,
-    const int colelements, const int strideelements_out, const int strideelements_in,
-    const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
-template void reducescatter2_userbuff_strided_atomic_fp8<te_hip_fp8_e5m2>(
-    void *output, float *scale, const int handler, const int offset, const int rowelements,
-    const int colelements, const int strideelements_out, const int strideelements_in,
-    const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
-template void reducescatter2_userbuff_strided_multiatomic_fp8<te_hip_fp8_e4m3>(
-    void *output, float *scale, const int handler, const int offset, const int rowelements,
-    const int colelements, const int strideelements_out, const int strideelements_in,
-    const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
-template void reducescatter2_userbuff_strided_multiatomic_fp8<te_hip_fp8_e5m2>(
-    void *output, float *scale, const int handler, const int offset, const int rowelements,
-    const int colelements, const int strideelements_out, const int strideelements_in,
-    const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
-#else
 template void reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e5m2>(
     void *output, float *scale, const int handler, const int offset, const int rowelements,
     const int colelements, const int strideelements, communicator *comm, cudaStream_t stream,
@@ -2125,7 +2078,6 @@ template void reducescatter2_userbuff_strided_multiatomic_fp8<__nv_fp8_e5m2>(
     void *output, float *scale, const int handler, const int offset, const int rowelements,
     const int colelements, const int strideelements_out, const int strideelements_in,
     const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
-#endif
 __global__ void kuserbuffers_pullsend(int myrank, int peer, int *send_id, int *flagptr) {
   atomicAdd_system(flagptr, 1);
@@ -2844,21 +2796,12 @@ void reduce_fp8_in_bf16_out(void *inputs, void *output, float *scale, int num_in
                             num_aligned_elements_per_input, tot_input_size);
 }
-#ifdef __HIP_PLATFORM_AMD__
-template void reduce_fp8_in_bf16_out<te_hip_fp8_e4m3>(void *inputs, void *output, float *scale,
-                                                       int num_inputs, int input_size,
-                                                       cudaStream_t stream);
-template void reduce_fp8_in_bf16_out<te_hip_fp8_e5m2>(void *inputs, void *output, float *scale,
-                                                       int num_inputs, int input_size,
-                                                       cudaStream_t stream);
-#else
 template void reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(void *inputs, void *output, float *scale,
                                                      int num_inputs, int input_size,
                                                      cudaStream_t stream);
 template void reduce_fp8_in_bf16_out<__nv_fp8_e5m2>(void *inputs, void *output, float *scale,
                                                      int num_inputs, int input_size,
                                                      cudaStream_t stream);
-#endif
 template <int nvec>
 __global__ void __launch_bounds__(MAX_THREADS / 4)
...
@@ -25,11 +25,7 @@
 #include <vector>
 #include "./nvtx.h"
-#ifdef __HIP_PLATFORM_AMD__
-#include "./util/hip_driver.h"
-#else
 #include "./util/cuda_driver.h"
-#endif
 #include "./util/logging.h"
 namespace transformer_engine {
@@ -248,15 +244,9 @@ using int64 = int64_t;
 using fp32 = float;
 using fp16 = half;
 using int8 = int8_t;
-#ifndef __HIP_PLATFORM_AMD__
 using bf16 = nv_bfloat16;
 using fp8e4m3 = __nv_fp8_e4m3;
 using fp8e5m2 = __nv_fp8_e5m2;
-#else
-using bf16 = __hip_bfloat16;
-using fp8e4m3 = te_hip_fp8_e4m3;
-using fp8e5m2 = te_hip_fp8_e5m2;
-#endif
 #if CUDA_VERSION >= 12080
 using fp8e8m0 = __nv_fp8_e8m0;
 #endif
@@ -277,15 +267,9 @@ TRANSFORMER_ENGINE_TYPE_NAME(int32_t)
 TRANSFORMER_ENGINE_TYPE_NAME(int64_t)
 TRANSFORMER_ENGINE_TYPE_NAME(float)
 TRANSFORMER_ENGINE_TYPE_NAME(half)
-#ifdef __HIP_PLATFORM_AMD__
-TRANSFORMER_ENGINE_TYPE_NAME(__hip_bfloat16)
-TRANSFORMER_ENGINE_TYPE_NAME(te_hip_fp8_e4m3)
-TRANSFORMER_ENGINE_TYPE_NAME(te_hip_fp8_e5m2)
-#else
 TRANSFORMER_ENGINE_TYPE_NAME(nv_bfloat16)
 TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e4m3)
 TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e5m2)
-#endif
 #if CUDA_VERSION >= 12080
 TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e8m0)
 #endif
@@ -296,11 +280,7 @@ struct TypeExtrema;
 template <>
 struct TypeExtrema<fp8e4m3> {
-#ifndef __HIP_PLATFORM_AMD__
   static constexpr float max = 448.0f;
-#else
-  static constexpr float max = 240.0f;
-#endif
 };
 template <>
...
@@ -46,17 +46,10 @@ static hipDataType get_hipblaslt_dtype(const transformer_engine::DType t) {
       return HIP_R_32F;
     case DType::kBFloat16:
       return HIP_R_16BF;
-#if HIP_VERSION >= 60300000
     case DType::kFloat8E4M3:
-      return te_fp8_fnuz() ? HIP_R_8F_E4M3_FNUZ : HIP_R_8F_E4M3;
+      return HIP_R_8F_E4M3;
     case DType::kFloat8E5M2:
-      return te_fp8_fnuz() ? HIP_R_8F_E5M2_FNUZ: HIP_R_8F_E5M2;
-#else
-    case DType::kFloat8E4M3:
-      return HIP_R_8F_E4M3_FNUZ;
-    case DType::kFloat8E5M2:
-      return HIP_R_8F_E5M2_FNUZ;
-#endif
+      return HIP_R_8F_E5M2;
     default:
       NVTE_ERROR("Invalid type");
   }
@@ -863,10 +856,7 @@ protected:
   }
 #if HIP_VERSION >= 60300000
-  auto fp8_filter = te_fp8_fnuz()
-                    ? [](const hipDataType& val)
-                      { return (val != HIP_R_8F_E4M3 && val != HIP_R_8F_E5M2); }
-                    : [](const hipDataType& val) {
+  auto fp8_filter = [](const hipDataType& val) {
     return (val != HIP_R_8F_E4M3_FNUZ && val != HIP_R_8F_E5M2_FNUZ);
   };
 #else
...
@@ -329,15 +329,9 @@ using int32 = int32_t;
 using fp32 = float;
 using fp16 = half;
 using int8 = int8_t;
-#ifndef __HIP_PLATFORM_AMD__
 using bf16 = nv_bfloat16;
 using fp8e4m3 = __nv_fp8_e4m3;
 using fp8e5m2 = __nv_fp8_e5m2;
-#else
-using bf16 = __hip_bfloat16;
-using fp8e4m3 = te_hip_fp8_e4m3;
-using fp8e5m2 = te_hip_fp8_e5m2;
-#endif
 template <typename T>
 struct TypeToDType;
...
@@ -9,8 +9,6 @@
 #include "../common.h"
 #ifdef __HIP_PLATFORM_AMD__
-using __nv_fp8_e4m3 = te_hip_fp8_e4m3;
-using __nv_fp8_e5m2 = te_hip_fp8_e5m2;
 #define __ldlu(x) __ldg(x)
 #endif
...
@@ -35,7 +35,7 @@ class Format(Enum):
     FP8 tensors in the backward pass are in e5m2 format
     """
-    E4M3 = _FormatHelper(max_fwd=448 if not IS_HIP_EXTENSION else 240.0, max_bwd=448 if not IS_HIP_EXTENSION else 240.0)
+    E4M3 = _FormatHelper(max_fwd=448, max_bwd=448)
     E5M2 = _FormatHelper(max_fwd=57344, max_bwd=57344)
     HYBRID = _FormatHelper(max_fwd=E4M3.max_fwd, max_bwd=E5M2.max_bwd)
...
@@ -32,11 +32,7 @@ const char* dtype_name(DType dtype) {
 inline float fp8_dtype_max(DType dtype) {
   switch (dtype) {
     case DType::kFloat8E4M3:
-#ifndef __HIP_PLATFORM_AMD__
       return 448;
-#else
-      return 240;
-#endif
     case DType::kFloat8E5M2:
       return 57344;
     default:
...
@@ -187,9 +187,9 @@ void KernelManager::compile(const std::string& kernel_label, const std::string&
   // Compile source
   nvrtcProgram program;
 #ifdef USE_ROCM
-  constexpr int num_headers = 4;
-  const char* headers[num_headers] = {string_code_utils_cuh, string_code_util_math_h, string_code_amd_detail_hip_float8_h, string_code_amd_detail_hip_f8_impl_h};
-  const char* include_names[num_headers] = {"utils_hip.cuh", "util/math.h", "amd_detail/hip_float8.h", "amd_detail/hip_f8_impl.h"};
+  constexpr int num_headers = 2;
+  const char* headers[num_headers] = {string_code_utils_cuh, string_code_util_math_h};
+  const char* include_names[num_headers] = {"utils_hip.cuh", "util/math.h"};
 #else
   constexpr int num_headers = 2;
   constexpr const char* headers[num_headers] = {string_code_utils_cuh, string_code_util_math_h};
...
@@ -982,13 +982,8 @@ __device__ __forceinline__ void reciprocal<float>(float *value_inv, const float
 ////////////////////////////////////////////////////////////////////////////////////////////////////
-#ifndef __HIP_PLATFORM_AMD__
 using fp8e4m3 = __nv_fp8_e4m3;
 using fp8e5m2 = __nv_fp8_e5m2;
-#else
-using fp8e4m3 = te_hip_fp8_e4m3;
-using fp8e5m2 = te_hip_fp8_e5m2;
-#endif
 using e8m0_t = uint8_t;
 using int8 = int8_t;
@@ -1003,11 +998,7 @@ struct Numeric_Traits;
 template <>
 struct Numeric_Traits<fp8e4m3> {
   static constexpr int maxUnbiasedExponent = 8;
-#ifndef __HIP_PLATFORM_AMD__
   static constexpr double maxNorm = 448;
-#else
-  static constexpr double maxNorm = 240;
-#endif
 };
 template <>
...
@@ -8,11 +8,8 @@
 #include <ATen/AccumulateType.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/Exceptions.h>
-#ifdef __HIP_PLATFORM_AMD__
-#include "amd_detail/hip_float8.h"
-#else
 #include <cuda_fp8.h>
-#endif
 // Another possibility:
 // #include <torch/all.h>
@@ -32,13 +29,8 @@ typedef enum {
 } adamMode_t;
 using MATH_T = float;
-#ifndef __HIP_PLATFORM_AMD__
 using fp8e4m3 = __nv_fp8_e4m3;
 using fp8e5m2 = __nv_fp8_e5m2;
-#else
-using fp8e4m3 = te_hip_fp8_e4m3;
-using fp8e5m2 = te_hip_fp8_e5m2;
-#endif
 using transformer_engine::DType;
 template <typename T>
...
@@ -197,7 +197,7 @@ class FusedAdam(torch.optim.Optimizer):
             torch.float16: torch.full(
                 [1], torch.finfo(torch.float16).max / 2.0, dtype=torch.float32
             ),
-            torch.uint8: torch.full([1], 448.0 if not IS_HIP_EXTENSION else 240.0, dtype=torch.float32),
+            torch.uint8: torch.full([1], 448.0, dtype=torch.float32),
         }
         self._scales = {}
         self.use_decoupled_grad = use_decoupled_grad
...
@@ -266,7 +266,7 @@ def _cast_master_weights_to_fp8_current_scaling(params, group, use_fsdp_shard_mo
     # Step 3: Update scales and scale_invs.
     # ---------------------------------------------------------------------------------------------
     if fp8_dtype == tex.DType.kFloat8E4M3:
-        max_fp8 = 448.0 if not IS_HIP_EXTENSION else 240.0
+        max_fp8 = 448.0
     elif fp8_dtype == tex.DType.kFloat8E5M2:
         max_fp8 = 57344.0
     else:
...