"tests/git@developer.sourcefind.cn:gaoqiong/pybind11.git" did not exist on "1491c94c9d742d2d7bfccfc57b3944d66e5e048a"
Commit ea272d4a authored by yuguo

[DCU] support for ROCm FP8 FNUZ and OCP formats

parent a248abb6
@@ -61,10 +61,9 @@ using bf16 = nv_bfloat16;
 using fp8e4m3 = __nv_fp8_e4m3;
 using fp8e5m2 = __nv_fp8_e5m2;
 #else
-using bf16 = __hip_bfloat16;
-using fp8e4m3 = hip_f8<hip_f8_type::fp8>;
-using fp8e5m2 = hip_f8<hip_f8_type::bf8>;
-#endif
+using fp8e4m3 = te_hip_fp8_e4m3;
+using fp8e5m2 = te_hip_fp8_e5m2;
+#endif //USE_ROCM
 using fp8e8m0 = uint8_t;
 template <typename T>
......
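The common `fp8e4m3`/`fp8e5m2` aliases now point at the `te_hip_fp8_e4m3`/`te_hip_fp8_e5m2` wrapper types defined in the FP8 header changed below, which pick the FNUZ or OCP encoding at runtime. A minimal host-side sketch of how such a wrapper is exercised (illustrative only, not part of this commit; the include path and values are assumptions, the conversion relies on the float constructor and `operator float()` shown in the header diff):

```cpp
// Illustrative sketch (not part of this commit): host-side round trip through
// the te_hip_fp8_e4m3 wrapper. Header path below is an assumption.
// #include "util/hip_fp8.h"
#include <cstdio>

void fp8_round_trip_demo() {
  float x = 1.75f;
  te_hip_fp8_e4m3 q(x);                 // quantize: float -> FP8 (FNUZ or OCP, chosen at runtime)
  float back = static_cast<float>(q);   // dequantize: FP8 -> float via operator float()
  std::printf("%f -> fp8 -> %f\n", x, back);
}
```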
@@ -4,10 +4,108 @@
  * License for AMD contributions = MIT. See LICENSE for more information
  ************************************************************************/
 #pragma once
-// FP8 header version 0.3, 2021/05/11
+#ifdef __HIPCC__
 #include <hip/hip_runtime.h>
+#if HIP_VERSION >= 60200000
+#include <hip/hip_fp8.h>
+
+#if HIP_VERSION >= 60300000
+#if !defined(__HIP_DEVICE_COMPILE__)
+#include <optional>
+#include "../util/string.h"
+
+/* Platforms that have both MI300 family and other families GPUs are unknown and not supported.
+ * Thus, FP8 format is selected once by the current (any) GPU architecture.
+ */
+static bool _te_check_fp8_fnuz() {
+  hipDeviceProp_t prop;
+  hipError_t res = hipGetDeviceProperties(&prop, 0);
+  if (res != hipSuccess) {
+    // TODO: better error out system
+    throw std::runtime_error(transformer_engine::concat_strings(
+        "hipGetDeviceProperties failed with error: ", hipGetErrorString(res)));
+  }
+  return prop.major == 9 && prop.minor == 4;
+}
+
+static inline bool te_fp8_fnuz() {
+  static std::optional<bool> use_fnuz;
+  if (!use_fnuz.has_value()) {
+    use_fnuz = _te_check_fp8_fnuz();
+  }
+  return use_fnuz.value();
+}
+
+/* Device methods in _te_hip_fp8 are dummy and are needed for compilation
+ * because HIPCC compiles __device__ and __global__ functions for host.
+ * The results are discarded so those methods are declared but not defined.
+ */
+template<typename FNUZ, typename OCP>
+union _te_hip_fp8 {
+  FNUZ fnuz;
+  OCP ocp;
+  __host__ __device__ _te_hip_fp8<FNUZ, OCP>() = default;
+  __host__ operator float() const {
+    return te_fp8_fnuz() ? fnuz.operator float() : ocp.operator float();
+  }
+  __device__ operator float() const;
+  __host__ _te_hip_fp8<FNUZ, OCP>(const float& v) {
+    if (te_fp8_fnuz()) fnuz = v; else ocp = v;
+  }
+  __device__ _te_hip_fp8<FNUZ, OCP>(const float& v);
+};
+
+typedef _te_hip_fp8<__hip_fp8_e4m3_fnuz, __hip_fp8_e4m3> _te_hip_fp8_e4m3;
+typedef _te_hip_fp8<__hip_fp8_e5m2_fnuz, __hip_fp8_e5m2> _te_hip_fp8_e5m2;
+
+#elif HIP_FP8_TYPE_FNUZ
+typedef __hip_fp8_e4m3_fnuz _te_hip_fp8_e4m3;
+typedef __hip_fp8_e5m2_fnuz _te_hip_fp8_e5m2;
+static inline bool te_fp8_fnuz() { return true; }
+#elif HIP_FP8_TYPE_OCP
+typedef __hip_fp8_e4m3 _te_hip_fp8_e4m3;
+typedef __hip_fp8_e5m2 _te_hip_fp8_e5m2;
+static inline bool te_fp8_fnuz() { return false; }
+#else
+#error "Unsupported HIP_FP8_TYPE"
+#endif //__HIP_DEVICE_COMPILE__
+
+#else //HIP_VERSION >= 60300000
+typedef __hip_fp8_e4m3_fnuz _te_hip_fp8_e4m3;
+typedef __hip_fp8_e5m2_fnuz _te_hip_fp8_e5m2;
+#endif //HIP_VERSION >= 60300000
+
+struct te_hip_fp8_e4m3 {
+  _te_hip_fp8_e4m3 data;
+  __host__ __device__ te_hip_fp8_e4m3() = default;
+  __host__ __device__ operator float() const { return data.operator float(); }
+  __host__ __device__ te_hip_fp8_e4m3(const float& v) { data = v; }
+};
+static_assert(sizeof(te_hip_fp8_e4m3) == 1, "Size mismatch");
+
+union te_hip_fp8_e5m2 {
+  _te_hip_fp8_e5m2 data;
+  __host__ __device__ te_hip_fp8_e5m2() = default;
+  __host__ __device__ operator float() const { return data.operator float(); }
+  __host__ __device__ te_hip_fp8_e5m2(const float& v) { data = v; }
+};
+static_assert(sizeof(te_hip_fp8_e5m2) == 1, "Size mismatch");
+
+#else //HIP_VERSION >= 60200000
+// FP8 header version 0.3, 2021/05/11
 #define HIP_HOST_DEVICE __host__ __device__
 #define HIP_DEVICE __device__
 #define HIP_HOST __host__
@@ -69,7 +167,6 @@ static inline __host__ bool get_hip_f8_bias_mode() {
 }
 #endif // __HIPCC_RTC__
-#ifdef __HIPCC__
 static __device__ bool hip_f8_bias_mode_bit_device = true;
 static inline __device__ bool get_hip_f8_bias_mode() {
@@ -91,7 +188,6 @@ static void set_hip_f8_bias_mode_optimal() {
   hip_f8_bias_mode_bit_host = true;
 }
 #endif // __HIPCC_RTC__
-#endif // __HIPCC__
 template<hip_f8_type T>
@@ -376,7 +472,6 @@ struct hip_f8 {
   }
 };
-#ifdef __HIPCC__
 template<hip_f8_type T>
 struct hip_f8x4 {
@@ -455,4 +550,13 @@ __device__ hip_float32x4 mfma_f32_16x16x32(hip_f8x8<T_A> a, hip_f8x8<T_B> b, hip
 template<hip_f8_type T_A, hip_f8_type T_B>
 __device__ hip_float32x16 mfma_f32_32x32x16(hip_f8x8<T_A> a, hip_f8x8<T_B> b, hip_float32x16 c);
+
+typedef hip_f8<hip_f8_type::fp8> te_hip_fp8_e4m3;
+typedef hip_f8<hip_f8_type::bf8> te_hip_fp8_e5m2;
+
+#endif //HIP_VERSION >= 60200000
+#else //__HIPCC__
+typedef struct {char storage;} te_hip_fp8_e4m3;
+typedef struct {char storage;} te_hip_fp8_e5m2;
 #endif //__HIPCC__
\ No newline at end of file
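The header selects the FP8 flavour once per process: `_te_check_fp8_fnuz()` treats gfx9.4 devices (`prop.major == 9 && prop.minor == 4`, i.e. the MI300 family) as FNUZ and everything else as OCP, and `te_fp8_fnuz()` caches the answer in a `std::optional`. A standalone sketch of the same probe (illustrative, not part of the commit; function and output names are made up):

```cpp
// Illustrative sketch (not part of this commit): probe device 0 the same way
// _te_check_fp8_fnuz() does and report which FP8 encoding would be selected.
#include <hip/hip_runtime.h>
#include <cstdio>
#include <stdexcept>

static bool uses_fnuz_fp8() {
  hipDeviceProp_t prop;
  hipError_t res = hipGetDeviceProperties(&prop, 0);
  if (res != hipSuccess) {
    throw std::runtime_error(hipGetErrorString(res));
  }
  // gfx94x (MI300 family) exposes the FNUZ encodings; other targets fall back to OCP here.
  return prop.major == 9 && prop.minor == 4;
}

int main() {
  std::printf("FP8 format: %s\n",
              uses_fnuz_fp8() ? "FNUZ (e4m3fnuz/e5m2fnuz)" : "OCP (e4m3/e5m2)");
  return 0;
}
```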
@@ -224,8 +224,8 @@ using fp8e4m3 = __nv_fp8_e4m3;
 using fp8e5m2 = __nv_fp8_e5m2;
 #else
 using bf16 = hip_bfloat16;
-using fp8e4m3 = hip_f8<hip_f8_type::fp8>;
-using fp8e5m2 = hip_f8<hip_f8_type::bf8>;
+using fp8e4m3 = te_hip_fp8_e4m3;
+using fp8e5m2 = te_hip_fp8_e5m2;
 #endif
 #if CUDA_VERSION >= 12080
 using fp8e8m0 = __nv_fp8_e8m0;
@@ -248,8 +248,8 @@ TRANSFORMER_ENGINE_TYPE_NAME(float)
 TRANSFORMER_ENGINE_TYPE_NAME(half)
 #ifdef __HIP_PLATFORM_AMD__
 TRANSFORMER_ENGINE_TYPE_NAME(hip_bfloat16)
-TRANSFORMER_ENGINE_TYPE_NAME(hip_f8<hip_f8_type::fp8>)
-TRANSFORMER_ENGINE_TYPE_NAME(hip_f8<hip_f8_type::bf8>)
+TRANSFORMER_ENGINE_TYPE_NAME(te_hip_fp8_e4m3)
+TRANSFORMER_ENGINE_TYPE_NAME(te_hip_fp8_e5m2)
 #else
 TRANSFORMER_ENGINE_TYPE_NAME(nv_bfloat16)
 TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e4m3)
......
@@ -36,30 +36,26 @@ namespace {
 #ifdef USE_HIPBLASLT
-#if HIP_VERSION >= 60000000
-typedef hipDataType hipblasltDatatype_t;
-typedef hipblasComputeType_t hipblasLtComputeType_t;
-#define HIPBLASLT_R_16F HIP_R_16F
-#define HIPBLASLT_R_32F HIP_R_32F
-#define HIPBLASLT_R_16B HIP_R_16BF
-#define HIPBLASLT_R_8F_E4M3 HIP_R_8F_E4M3_FNUZ
-#define HIPBLASLT_R_8F_E5M2 HIP_R_8F_E5M2_FNUZ
-#define HIPBLASLT_COMPUTE_F32 HIPBLAS_COMPUTE_32F
-#endif // #if HIP_VERSION >= 60000000
-hipblasltDatatype_t get_hipblaslt_dtype(const transformer_engine::DType t) {
+static hipDataType get_hipblaslt_dtype(const transformer_engine::DType t) {
   using namespace transformer_engine;
   switch (t) {
     case DType::kFloat16:
-      return HIPBLASLT_R_16F;
+      return HIP_R_16F;
     case DType::kFloat32:
-      return HIPBLASLT_R_32F;
+      return HIP_R_32F;
     case DType::kBFloat16:
-      return HIPBLASLT_R_16B;
+      return HIP_R_16BF;
+#if HIP_VERSION >= 60300000
     case DType::kFloat8E4M3:
-      return HIPBLASLT_R_8F_E4M3;
+      return te_fp8_fnuz() ? HIP_R_8F_E4M3_FNUZ : HIP_R_8F_E4M3;
     case DType::kFloat8E5M2:
-      return HIPBLASLT_R_8F_E5M2;
+      return te_fp8_fnuz() ? HIP_R_8F_E5M2_FNUZ : HIP_R_8F_E5M2;
+#else
+    case DType::kFloat8E4M3:
+      return HIP_R_8F_E4M3_FNUZ;
+    case DType::kFloat8E5M2:
+      return HIP_R_8F_E5M2_FNUZ;
+#endif
     default:
       NVTE_ERROR("Invalid type");
   }
@@ -367,11 +363,7 @@ void bias_gradient_kernelLauncher(const Tin* in, float* out, int m, int n, bool
     if(! stream_order_alloc){
         NVTE_CHECK_CUDA( hipMemset(out, 0, n*sizeof(float)) );
     }else{
-#if HIP_VERSION >= 50300000
         NVTE_CHECK_CUDA( hipMemsetAsync(out, 0, n*sizeof(float), stream) );
-#else
-        NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
-#endif
     }
     hipLaunchKernelGGL(( bias_gradient_kernel<Tin, THREADS_PER_BLOCK>), dim3(grid), dim3(block), 0, stream, in, out, m, n);
 }
@@ -575,11 +567,11 @@ public:
   const std::string_view &getName(const T &val) {
     return map.at(val);
   }
-  T getValue(const std::string& name, const char *label="")
+  T getValue(const std::string& name, const char *label="", std::function<bool(const T&)> filter = nullptr)
   {
     for (auto iter = map.begin(); iter != map.end(); ++iter)
     {
-      if (name == iter->second) return iter->first;
+      if ((name == iter->second) && (!filter || filter(iter->first))) return iter->first;
     }
     NVTE_ERROR("Invalid ", label, " name: ", name);
   }
@@ -587,14 +579,18 @@ protected:
   const std::unordered_map<T, std::string_view> &map;
 };
-static std::unordered_map<hipblasltDatatype_t, std::string_view> type_name_map = {
-  {HIPBLASLT_R_32F, "float32"},
-  {HIPBLASLT_R_16F, "float16"},
-  {HIPBLASLT_R_16B, "bfloat16"},
-  {HIPBLASLT_R_8F_E4M3, "float8e4m3"},
-  {HIPBLASLT_R_8F_E5M2, "float8e5m2"},
+static std::unordered_map<hipDataType, std::string_view> type_name_map = {
+  {HIP_R_32F, "float32"},
+  {HIP_R_16F, "float16"},
+  {HIP_R_16BF, "bfloat16"},
+  {HIP_R_8F_E4M3_FNUZ, "float8e4m3"},
+  {HIP_R_8F_E5M2_FNUZ, "float8e5m2"},
+#if HIP_VERSION >= 60300000
+  {HIP_R_8F_E4M3, "float8e4m3"},
+  {HIP_R_8F_E5M2, "float8e5m2"},
+#endif
 };
-static NameMapper<hipblasltDatatype_t> typeNameMapper(type_name_map);
+static NameMapper<hipDataType> typeNameMapper(type_name_map);
 static std::unordered_map<hipblasOperation_t, std::string_view> trans_name_map = {
   {HIPBLAS_OP_N, "N"},
@@ -613,24 +609,24 @@ static std::unordered_map<hipblasLtEpilogue_t, std::string_view> epi_name_map =
 };
 static NameMapper<hipblasLtEpilogue_t> epilogueNameMapper(epi_name_map);
-static std::unordered_map<hipblasLtComputeType_t, std::string_view> comp_name_map = {
-  {HIPBLASLT_COMPUTE_F32, "f32"}
+static std::unordered_map<hipblasComputeType_t, std::string_view> comp_name_map = {
+  {HIPBLAS_COMPUTE_32F, "f32"}
 };
-static NameMapper<hipblasLtComputeType_t> computeNameMapper(comp_name_map);
+static NameMapper<hipblasComputeType_t> computeNameMapper(comp_name_map);
 static class GemmAlgoCache {
 public:
   struct Key {
     int deviceCap;
-    hipblasltDatatype_t a_type, b_type, d_type, bias_type;
+    hipDataType a_type, b_type, d_type, bias_type;
     int m, n, k;
     int lda, ldb, ldd;
     hipblasOperation_t transa, transb;
     hipblasLtEpilogue_t epilogue;
     Key(int deviceCap_,
-        hipblasltDatatype_t a_type_, hipblasltDatatype_t b_type_,
-        hipblasltDatatype_t d_type_, hipblasltDatatype_t bias_type_,
+        hipDataType a_type_, hipDataType b_type_,
+        hipDataType d_type_, hipDataType bias_type_,
         int m_, int n_, int k_, int lda_, int ldb_, int ldd_,
         hipblasOperation_t transa_, hipblasOperation_t transb_,
         hipblasLtEpilogue_t epilogue_):
@@ -864,18 +860,32 @@ protected:
       std::cout << "[WARNING] Invalid WS size at " << line << "\n";
       continue;
     }
-    cfg.a_type = typeNameMapper.getValue(type_a, "type_a");
-    cfg.b_type = typeNameMapper.getValue(type_b, "type_b");
-    cfg.d_type = typeNameMapper.getValue(type_d, "type_d");
-    cfg.bias_type = (bias_type == "-") ? (hipblasltDatatype_t)-1 : typeNameMapper.getValue(bias_type, "bias_type");
+#if HIP_VERSION >= 60300000
+    auto fp8_filter = te_fp8_fnuz()
+      ? [](const hipDataType& val)
+        { return (val != HIP_R_8F_E4M3 && val != HIP_R_8F_E5M2); }
+      : [](const hipDataType& val) {
+          return (val != HIP_R_8F_E4M3_FNUZ && val != HIP_R_8F_E5M2_FNUZ);
+        };
+#else
+    auto fp8_filter = nullptr;
+#endif
+    cfg.a_type = typeNameMapper.getValue(type_a, "type_a", fp8_filter);
+    cfg.b_type = typeNameMapper.getValue(type_b, "type_b", fp8_filter);
+    cfg.d_type = typeNameMapper.getValue(type_d, "type_d", fp8_filter);
+    cfg.bias_type = (bias_type == "-")
+      ? (hipDataType)-1
+      : typeNameMapper.getValue(bias_type, "bias_type", fp8_filter);
     cfg.transa = transposeNameMapper.getValue(trans_a, "trans_a");
     cfg.transb = transposeNameMapper.getValue(trans_b, "trans_b");
    cfg.epilogue = epilogueNameMapper.getValue(epi, "epi");
     //Check and filter out compute and scale types
-    if (computeNameMapper.getValue(comp, "comp") != HIPBLASLT_COMPUTE_F32 || typeNameMapper.getValue(scale, "scale") != HIPBLASLT_R_32F)
+    if (computeNameMapper.getValue(comp, "comp") != HIPBLAS_COMPUTE_32F ||
+        typeNameMapper.getValue(scale, "scale") != HIP_R_32F)
     {
       continue;
     }
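Because the FNUZ and OCP enum values both register under the same names ("float8e4m3", "float8e5m2") in type_name_map, a plain name lookup would be ambiguous; the new filter argument to getValue lets the algo-cache loader keep only the enum that matches the format reported by te_fp8_fnuz(). A simplified, self-contained sketch of that lookup pattern (illustrative only; the real NameMapper lives in this file, all names below are made up):

```cpp
// Illustrative sketch (not part of this commit): name -> enum lookup where two
// enum values share a name and a predicate selects the one valid for this run.
#include <functional>
#include <stdexcept>
#include <string>
#include <string_view>
#include <unordered_map>

enum class Fp8Kind { E4M3_FNUZ, E4M3_OCP };

static const std::unordered_map<Fp8Kind, std::string_view> kNames = {
  {Fp8Kind::E4M3_FNUZ, "float8e4m3"},
  {Fp8Kind::E4M3_OCP,  "float8e4m3"},  // same user-visible name, different encoding
};

static Fp8Kind lookup(const std::string& name, std::function<bool(const Fp8Kind&)> filter) {
  for (const auto& [value, text] : kNames) {
    if (name == text && (!filter || filter(value))) return value;
  }
  throw std::invalid_argument("unknown type name: " + name);
}

// Usage: on an FNUZ device, reject the OCP value so "float8e4m3" resolves to E4M3_FNUZ.
// Fp8Kind k = lookup("float8e4m3", [](const Fp8Kind& v) { return v != Fp8Kind::E4M3_OCP; });
```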
@@ -958,9 +968,9 @@ protected:
     csv << cfg.deviceCap << cfg.m << cfg.n << cfg.k
         << transposeNameMapper.getName(cfg.transa) << transposeNameMapper.getName(cfg.transb)
         << typeNameMapper.getName(cfg.a_type) << typeNameMapper.getName(cfg.b_type) << typeNameMapper.getName(cfg.d_type)
-        << ((cfg.bias_type == (hipblasltDatatype_t)-1) ? "-" : typeNameMapper.getName(cfg.bias_type))
+        << ((cfg.bias_type == (hipDataType)-1) ? "-" : typeNameMapper.getName(cfg.bias_type))
         << cfg.lda << cfg.ldb << cfg.ldd << epilogueNameMapper.getName(cfg.epilogue)
-        << computeNameMapper.getName(HIPBLASLT_COMPUTE_F32) << typeNameMapper.getName(HIPBLASLT_R_32F)
+        << computeNameMapper.getName(HIPBLAS_COMPUTE_32F) << typeNameMapper.getName(HIP_R_32F)
         << algo.ws_size_min << algo.ws_size_max << algo.algoId << algo.index << csv_helper::end() << "\n";
   }
@@ -1026,10 +1036,10 @@ void hipblaslt_gemm(const Tensor *inputA,
   const bool gelu = pre_gelu_out != nullptr;
   const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) ||
                        is_fp8_dtype(inputB->data.dtype);
-  const hipblasltDatatype_t A_type = get_hipblaslt_dtype(inputA->data.dtype);
-  const hipblasltDatatype_t B_type = get_hipblaslt_dtype(inputB->data.dtype);
-  const hipblasltDatatype_t D_type = get_hipblaslt_dtype(outputD->data.dtype);
-  const hipblasltDatatype_t bias_type = get_hipblaslt_dtype(inputBias->data.dtype);
+  const hipDataType A_type = get_hipblaslt_dtype(inputA->data.dtype);
+  const hipDataType B_type = get_hipblaslt_dtype(inputB->data.dtype);
+  const hipDataType D_type = get_hipblaslt_dtype(outputD->data.dtype);
+  const hipDataType bias_type = get_hipblaslt_dtype(inputBias->data.dtype);
   NVTE_CHECK(!is_fp8_dtype(inputA->data.dtype) || A_scale_inverse != nullptr,
              "FP8 input to GEMM requires inverse of scale!");
@@ -1063,7 +1073,7 @@ void hipblaslt_gemm(const Tensor *inputA,
   int64_t ld_gelumat = (int64_t) ldd;
   // default to tf32 except for e5m2 inputs where the config is not supported
-  hipblasLtComputeType_t gemm_compute_type = HIPBLASLT_COMPUTE_F32;
+  hipblasComputeType_t gemm_compute_type = HIPBLAS_COMPUTE_32F;
   // Create matrix descriptors. Not setting any extra attributes.
   NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Adesc, A_type,
@@ -1076,7 +1086,7 @@ void hipblaslt_gemm(const Tensor *inputA,
                                                    ldb));
   NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd));
-  NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescCreate(&operationDesc, gemm_compute_type, HIPBLASLT_R_32F));
+  NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescCreate(&operationDesc, gemm_compute_type, HIP_R_32F));
   NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_TRANSA,
                                                        &transa, sizeof(transa)));
   NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_TRANSB,
@@ -1153,7 +1163,7 @@ void hipblaslt_gemm(const Tensor *inputA,
                                                        &epilogue, sizeof(epilogue)));
   GemmAlgoCache::Key gemm_cfg(algoCache.device_cap(device_id), A_type, B_type, D_type,
-                              use_fp8 ? bias_type : (hipblasltDatatype_t)-1,
+                              use_fp8 ? bias_type : (hipDataType)-1,
                               m, n, k, lda, ldb, ldd, transa, transb, epilogue );
   GemmAlgoCache::Algo cached_algo;
   if (algoCache.find(gemm_cfg, workspaceSize, cached_algo) == 0 || !cached_algo.algo.has_value())
@@ -1468,11 +1478,7 @@ void rocblas_gemm(const Tensor *inputA,
     if(! stream_order_alloc){
         NVTE_CHECK_CUDA( hipMalloc(&D_temp, sizeof(float)*m*n) );
     }else{
-#if HIP_VERSION >= 50300000
         NVTE_CHECK_CUDA( hipMallocAsync(&D_temp, sizeof(float)*m*n, stream) );
-#else
-        NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
-#endif
     }
   }else {
     D_temp = D;
@@ -1565,11 +1571,7 @@ void rocblas_gemm(const Tensor *inputA,
     if(! stream_order_alloc){
         NVTE_CHECK_CUDA( hipMalloc(&bias_tmp, sizeof(float)*input_dim) ); // The bias gradient is for the first linear layer
     }else{
-#if HIP_VERSION >= 50300000
         NVTE_CHECK_CUDA( hipMallocAsync(&bias_tmp, sizeof(float)*input_dim, stream) );
-#else
-        NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
-#endif
     }
   }else {
     bias_tmp = bias_ptr;
@@ -1595,11 +1597,7 @@ void rocblas_gemm(const Tensor *inputA,
     if(! stream_order_alloc){
         NVTE_CHECK_CUDA( hipFree(bias_tmp) );
     }else{
-#if HIP_VERSION >= 50300000
         NVTE_CHECK_CUDA( hipFreeAsync(bias_tmp, stream) );
-#else
-        NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
-#endif
     }
   }
@@ -1647,11 +1645,7 @@ void rocblas_gemm(const Tensor *inputA,
     if(! stream_order_alloc){
        NVTE_CHECK_CUDA( hipMalloc(&bias_tmp, sizeof(float)*output_dim) );
     }else{
-#if HIP_VERSION >= 50300000
        NVTE_CHECK_CUDA( hipMallocAsync(&bias_tmp, sizeof(float)*output_dim, stream) );
-#else
-       NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
-#endif
     }
   }else {
     bias_tmp = bias_ptr;
@@ -1678,11 +1672,7 @@ void rocblas_gemm(const Tensor *inputA,
     if(! stream_order_alloc){
        NVTE_CHECK_CUDA( hipFree(bias_tmp) );
     }else{
-#if HIP_VERSION >= 50300000
        NVTE_CHECK_CUDA( hipFreeAsync(bias_tmp, stream) );
-#else
-       NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
-#endif
     }
   }
   if (D_type == rocblas_datatype_f16_r || D_type == rocblas_datatype_bf16_r) {
@@ -1783,11 +1773,7 @@ void rocblas_gemm(const Tensor *inputA,
     if(! stream_order_alloc){
        NVTE_CHECK_CUDA( hipFree(D_temp) );
     }else{
-#if HIP_VERSION >= 50300000
        NVTE_CHECK_CUDA( hipFreeAsync(D_temp, stream) );
-#else
-       NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
-#endif
     }
   }
 }
......
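All the HIP_VERSION >= 50300000 guards around the stream-ordered allocation calls are dropped above, so this file now assumes hipMallocAsync, hipMemsetAsync, and hipFreeAsync are always available, which holds on any ROCm recent enough to carry the FP8 types this commit targets. A minimal sketch of that allocation pattern (illustrative only, not part of the commit; function name and usage are made up):

```cpp
// Illustrative sketch (not part of this commit): stream-ordered scratch allocation,
// the pattern the removed version guards used to protect.
#include <hip/hip_runtime.h>

hipError_t zeroed_scratch_demo(float** out, size_t n, hipStream_t stream) {
  hipError_t err = hipMallocAsync(reinterpret_cast<void**>(out), n * sizeof(float), stream);
  if (err != hipSuccess) return err;
  err = hipMemsetAsync(*out, 0, n * sizeof(float), stream);
  if (err != hipSuccess) return err;
  // ... enqueue work on `stream` that consumes *out ...
  return hipFreeAsync(*out, stream);
}
```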
@@ -36,8 +36,8 @@ using MATH_T = float;
 using fp8e4m3 = __nv_fp8_e4m3;
 using fp8e5m2 = __nv_fp8_e5m2;
 #else
-using fp8e4m3 = hip_f8<hip_f8_type::fp8>;
-using fp8e5m2 = hip_f8<hip_f8_type::bf8>;
+using fp8e4m3 = te_hip_fp8_e4m3;
+using fp8e5m2 = te_hip_fp8_e5m2;
 #endif
 using transformer_engine::DType;
......