"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "1c0aabdeb0cf77019a1f89b5bed5b8eebdd5c211"
Unverified commit 4b7869d6, authored by Matthias Gehre, committed by GitHub
Browse files

[ROCm] Add gfx1102/gfx1103 support (#40037)


Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
parent 4a79262e
...@@ -37,7 +37,7 @@ install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS) ...@@ -37,7 +37,7 @@ install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
set(PYTHON_SUPPORTED_VERSIONS "3.10" "3.11" "3.12" "3.13") set(PYTHON_SUPPORTED_VERSIONS "3.10" "3.11" "3.12" "3.13")
# Supported AMD GPU architectures. # Supported AMD GPU architectures.
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1150;gfx1151;gfx1152;gfx1153;gfx1200;gfx1201") set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1102;gfx1103;gfx1150;gfx1151;gfx1152;gfx1153;gfx1200;gfx1201")
# ROCm installation prefix. Default to /opt/rocm but allow override via # ROCm installation prefix. Default to /opt/rocm but allow override via
# -DROCM_PATH=/your/rocm/path when invoking cmake. # -DROCM_PATH=/your/rocm/path when invoking cmake.
......
...@@ -40,15 +40,6 @@ using __hip_fp8_e5m2 = __hip_fp8_e5m2_fnuz; ...@@ -40,15 +40,6 @@ using __hip_fp8_e5m2 = __hip_fp8_e5m2_fnuz;
#define __HIP__FP8MFMA__ #define __HIP__FP8MFMA__
#endif #endif
#if defined(__HIPCC__) && (defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1150__) || defined(__gfx1151__))
#define __HIP__GFX11__
#endif
#if defined(__HIPCC__) && (defined(__gfx1200__) || defined(__gfx1201__))
#define __HIP__GFX12__
#endif
#if defined(NDEBUG) #if defined(NDEBUG)
#undef NDEBUG #undef NDEBUG
#include <assert.h> #include <assert.h>
...@@ -1629,7 +1620,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( ...@@ -1629,7 +1620,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
} }
} }
#elif defined(__HIP__GFX11__) #elif defined(__GFX11__)
using floatx8 = __attribute__((__vector_size__(8 * sizeof(float)))) float; using floatx8 = __attribute__((__vector_size__(8 * sizeof(float)))) float;
...@@ -2388,7 +2379,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( ...@@ -2388,7 +2379,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
out_ptr[threadIdx.x] = from_float<scalar_t>(acc); out_ptr[threadIdx.x] = from_float<scalar_t>(acc);
} }
#elif defined(__HIP__GFX12__) #elif defined(__GFX12__)
using floatx8 = __attribute__((__vector_size__(8 * sizeof(float)))) float; using floatx8 = __attribute__((__vector_size__(8 * sizeof(float)))) float;
......
...@@ -26,16 +26,11 @@ ...@@ -26,16 +26,11 @@
#define __HIP__GFX9__ #define __HIP__GFX9__
#endif #endif
#if defined(__HIPCC__) && \ // Combined RDNA macro (gfx11 + gfx12) - both use 32-wide wavefronts
(defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1150__) || \ #if defined(__GFX11__) || defined(__GFX12__)
defined(__gfx1151__) || defined(__gfx1200__) || defined(__gfx1201__))
#define __HIP__GFX1X__ #define __HIP__GFX1X__
#endif #endif
#if defined(__HIPCC__) && (defined(__gfx1200__) || defined(__gfx1201__))
#define __HIP__GFX12__
#endif
#if defined(__HIPCC__) && (defined(__gfx942__) || defined(__gfx950__)) #if defined(__HIPCC__) && (defined(__gfx942__) || defined(__gfx950__))
#define __HIP__MI3XX__ #define __HIP__MI3XX__
#endif #endif
...@@ -1845,7 +1840,7 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b, ...@@ -1845,7 +1840,7 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
return out_c; return out_c;
} }
#if defined(__HIP__MI3XX__) || defined(__HIP__GFX12__) #if defined(__HIP__MI3XX__) || defined(__GFX12__)
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp, template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N> int A_CHUNK, int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS) __global__ void __launch_bounds__(WvPrGrp* THRDS)
...@@ -1893,7 +1888,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -1893,7 +1888,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
float sB = *s_B; float sB = *s_B;
while (m < M) { while (m < M) {
#ifdef __HIP__GFX12__ #ifdef __GFX12__
// gfx12: per-lane scalar accumulation via v_dot4_f32_fp8_fp8 // gfx12: per-lane scalar accumulation via v_dot4_f32_fp8_fp8
float sum[N][YTILE] = {}; float sum[N][YTILE] = {};
#else #else
...@@ -1931,7 +1926,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -1931,7 +1926,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#pragma unroll #pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) { for (uint32_t k2 = 0; k2 < UNRL; k2++) {
for (uint32_t n = 0; n < N; n++) { for (uint32_t n = 0; n < N; n++) {
#ifdef __HIP__GFX12__ #ifdef __GFX12__
// gfx12: 4 x dot4 per A_CHUNK=16 bytes (4 FP8 per dot4) // gfx12: 4 x dot4 per A_CHUNK=16 bytes (4 FP8 per dot4)
for (int y = 0; y < YTILE; ++y) { for (int y = 0; y < YTILE; ++y) {
#pragma unroll #pragma unroll
...@@ -1955,7 +1950,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -1955,7 +1950,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
} }
// Final reduction // Final reduction
#ifdef __HIP__GFX12__ #ifdef __GFX12__
// gfx12 wave32: DPP row_shr within 16-lane rows + cross-row shuffle // gfx12 wave32: DPP row_shr within 16-lane rows + cross-row shuffle
for (int n = 0; n < N; n++) { for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) { for (int y = 0; y < YTILE; y++) {
...@@ -1993,7 +1988,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -1993,7 +1988,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#endif #endif
const bool writeback_lane = const bool writeback_lane =
#ifdef __HIP__GFX12__ #ifdef __GFX12__
threadIdx.x == (THRDS - 1); threadIdx.x == (THRDS - 1);
#else #else
threadIdx.x == 0; threadIdx.x == 0;
...@@ -2009,7 +2004,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -2009,7 +2004,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for (int n = 0; n < N; n++) { for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) { for (int y = 0; y < YTILE; y++) {
if (y + m >= M) break; // To avoid mem access fault. if (y + m >= M) break; // To avoid mem access fault.
#ifdef __HIP__GFX12__ #ifdef __GFX12__
float result = sum[n][y] * sA * sB; float result = sum[n][y] * sA * sB;
#else #else
float result = sum[n][y][0] * sA * sB; float result = sum[n][y][0] * sA * sB;
...@@ -2027,7 +2022,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -2027,7 +2022,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
m += CuCount * _WvPrGrp * YTILE; m += CuCount * _WvPrGrp * YTILE;
} }
} }
#else // !defined(__HIP__MI3XX__) && !defined(__HIP__GFX12__) #else // !defined(__HIP__MI3XX__) && !defined(__GFX12__)
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp, template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N> int A_CHUNK, int UNRL, int N>
__global__ void wvSplitKQ_hf_sml_(const int K, const int Kap, const int Kbp, __global__ void wvSplitKQ_hf_sml_(const int K, const int Kap, const int Kbp,
...@@ -2039,9 +2034,9 @@ __global__ void wvSplitKQ_hf_sml_(const int K, const int Kap, const int Kbp, ...@@ -2039,9 +2034,9 @@ __global__ void wvSplitKQ_hf_sml_(const int K, const int Kap, const int Kbp,
const int _WvPrGrp, const int CuCount) { const int _WvPrGrp, const int CuCount) {
UNREACHABLE_CODE UNREACHABLE_CODE
} }
#endif // defined(__HIP__MI3XX__) || defined(__HIP__GFX12__) #endif // defined(__HIP__MI3XX__) || defined(__GFX12__)
#if defined(__HIP__MI3XX__) || defined(__HIP__GFX12__) #if defined(__HIP__MI3XX__) || defined(__GFX12__)
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp, template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N> int A_CHUNK, int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS) __global__ void __launch_bounds__(WvPrGrp* THRDS)
...@@ -2088,7 +2083,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -2088,7 +2083,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
float sB = *s_B; float sB = *s_B;
while (m < M) { while (m < M) {
#ifdef __HIP__GFX12__ #ifdef __GFX12__
// gfx12: per-lane scalar accumulation via v_dot4_f32_fp8_fp8 // gfx12: per-lane scalar accumulation via v_dot4_f32_fp8_fp8
float sum[N][YTILE] = {}; float sum[N][YTILE] = {};
#else #else
...@@ -2128,7 +2123,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -2128,7 +2123,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#pragma unroll #pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) { for (uint32_t k2 = 0; k2 < UNRL; k2++) {
for (uint32_t n = 0; n < N; n++) { for (uint32_t n = 0; n < N; n++) {
#ifdef __HIP__GFX12__ #ifdef __GFX12__
// gfx12: 4 x dot4 per A_CHUNK=16 bytes (4 FP8 per dot4) // gfx12: 4 x dot4 per A_CHUNK=16 bytes (4 FP8 per dot4)
for (int y = 0; y < YTILE; ++y) { for (int y = 0; y < YTILE; ++y) {
#pragma unroll #pragma unroll
...@@ -2152,7 +2147,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -2152,7 +2147,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
} }
// Final reduction // Final reduction
#ifdef __HIP__GFX12__ #ifdef __GFX12__
// gfx12 wave32: DPP row_shr within 16-lane rows + cross-row shuffle // gfx12 wave32: DPP row_shr within 16-lane rows + cross-row shuffle
for (int n = 0; n < N; n++) { for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) { for (int y = 0; y < YTILE; y++) {
...@@ -2190,7 +2185,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -2190,7 +2185,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#endif #endif
const bool writeback_lane = const bool writeback_lane =
#ifdef __HIP__GFX12__ #ifdef __GFX12__
threadIdx.x == (THRDS - 1); threadIdx.x == (THRDS - 1);
#else #else
threadIdx.x == 0; threadIdx.x == 0;
...@@ -2206,7 +2201,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -2206,7 +2201,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for (int n = 0; n < N; n++) { for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) { for (int y = 0; y < YTILE; y++) {
if (y + m >= M) break; // To avoid mem access fault. if (y + m >= M) break; // To avoid mem access fault.
#ifdef __HIP__GFX12__ #ifdef __GFX12__
float result = sum[n][y] * sA * sB; float result = sum[n][y] * sA * sB;
#else #else
float result = sum[n][y][0] * sA * sB; float result = sum[n][y][0] * sA * sB;
...@@ -2224,7 +2219,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ...@@ -2224,7 +2219,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
m += CuCount * _WvPrGrp * YTILE; m += CuCount * _WvPrGrp * YTILE;
} }
} }
#else // !defined(__HIP__MI3XX__) && !defined(__HIP__GFX12__) #else // !defined(__HIP__MI3XX__) && !defined(__GFX12__)
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp, template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N> int A_CHUNK, int UNRL, int N>
__global__ void wvSplitKQ_hf_(const int K, const int Kap, const int Kbp, __global__ void wvSplitKQ_hf_(const int K, const int Kap, const int Kbp,
...@@ -2236,7 +2231,7 @@ __global__ void wvSplitKQ_hf_(const int K, const int Kap, const int Kbp, ...@@ -2236,7 +2231,7 @@ __global__ void wvSplitKQ_hf_(const int K, const int Kap, const int Kbp,
const int CuCount) { const int CuCount) {
UNREACHABLE_CODE UNREACHABLE_CODE
} }
#endif // defined(__HIP__MI3XX__) || defined(__HIP__GFX12__) #endif // defined(__HIP__MI3XX__) || defined(__GFX12__)
void wvSplitKQ(const at::Tensor& in_b, const at::Tensor& in_a, void wvSplitKQ(const at::Tensor& in_b, const at::Tensor& in_a,
const std::optional<at::Tensor>& in_bias, at::Tensor& out_c, const std::optional<at::Tensor>& in_bias, at::Tensor& out_c,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment