Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
90eeea8f
Unverified
Commit
90eeea8f
authored
Jul 24, 2025
by
Gregory Shtrasberg
Committed by
GitHub
Jul 24, 2025
Browse files
[Bugfix][ROCm] Fix for warp_size uses on host (#21205)
Signed-off-by:
Gregory Shtrasberg
<
Gregory.Shtrasberg@amd.com
>
parent
dde295a9
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
67 additions
and
31 deletions
+67
-31
csrc/attention/attention_kernels.cuh
csrc/attention/attention_kernels.cuh
+1
-1
csrc/attention/paged_attention_v1.cu
csrc/attention/paged_attention_v1.cu
+2
-3
csrc/attention/paged_attention_v2.cu
csrc/attention/paged_attention_v2.cu
+2
-3
csrc/cuda_compat.h
csrc/cuda_compat.h
+29
-2
csrc/moe/topk_softmax_kernels.cu
csrc/moe/topk_softmax_kernels.cu
+29
-18
csrc/quantization/activation_kernels.cu
csrc/quantization/activation_kernels.cu
+1
-1
csrc/quantization/gguf/gguf_kernel.cu
csrc/quantization/gguf/gguf_kernel.cu
+1
-1
csrc/rocm/attention.cu
csrc/rocm/attention.cu
+1
-1
csrc/rocm/skinny_gemms.cu
csrc/rocm/skinny_gemms.cu
+1
-1
No files found.
csrc/attention/attention_kernels.cuh
View file @
90eeea8f
...
@@ -24,7 +24,7 @@
...
@@ -24,7 +24,7 @@
#include "attention_dtypes.h"
#include "attention_dtypes.h"
#include "attention_utils.cuh"
#include "attention_utils.cuh"
#include "cuda_compat.h"
#include "
../
cuda_compat.h"
#ifdef USE_ROCM
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
#include <hip/hip_bf16.h>
...
...
csrc/attention/paged_attention_v1.cu
View file @
90eeea8f
...
@@ -16,9 +16,8 @@
...
@@ -16,9 +16,8 @@
* See the License for the specific language governing permissions and
* See the License for the specific language governing permissions and
* limitations under the License.
* limitations under the License.
*/
*/
#include "attention_kernels.cuh"
#include "attention_kernels.cuh"
#include "cuda_compat.h"
#include "
../
cuda_compat.h"
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
...
@@ -75,7 +74,7 @@ void paged_attention_v1_launcher(
...
@@ -75,7 +74,7 @@ void paged_attention_v1_launcher(
const
float
*
k_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
k_scale
.
data_ptr
());
const
float
*
k_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
k_scale
.
data_ptr
());
const
float
*
v_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
v_scale
.
data_ptr
());
const
float
*
v_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
v_scale
.
data_ptr
());
const
expr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
const
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
int
padded_max_seq_len
=
int
padded_max_seq_len
=
DIVIDE_ROUND_UP
(
max_seq_len
,
BLOCK_SIZE
)
*
BLOCK_SIZE
;
DIVIDE_ROUND_UP
(
max_seq_len
,
BLOCK_SIZE
)
*
BLOCK_SIZE
;
int
logits_size
=
padded_max_seq_len
*
sizeof
(
float
);
int
logits_size
=
padded_max_seq_len
*
sizeof
(
float
);
...
...
csrc/attention/paged_attention_v2.cu
View file @
90eeea8f
...
@@ -16,9 +16,8 @@
...
@@ -16,9 +16,8 @@
* See the License for the specific language governing permissions and
* See the License for the specific language governing permissions and
* limitations under the License.
* limitations under the License.
*/
*/
#include "attention_kernels.cuh"
#include "attention_kernels.cuh"
#include "cuda_compat.h"
#include "
../
cuda_compat.h"
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
...
@@ -79,7 +78,7 @@ void paged_attention_v2_launcher(
...
@@ -79,7 +78,7 @@ void paged_attention_v2_launcher(
const
float
*
k_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
k_scale
.
data_ptr
());
const
float
*
k_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
k_scale
.
data_ptr
());
const
float
*
v_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
v_scale
.
data_ptr
());
const
float
*
v_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
v_scale
.
data_ptr
());
const
expr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
const
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
int
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_seq_len
,
PARTITION_SIZE
);
int
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_seq_len
,
PARTITION_SIZE
);
int
logits_size
=
PARTITION_SIZE
*
sizeof
(
float
);
int
logits_size
=
PARTITION_SIZE
*
sizeof
(
float
);
int
outputs_size
=
(
NUM_WARPS
/
2
)
*
head_size
*
sizeof
(
float
);
int
outputs_size
=
(
NUM_WARPS
/
2
)
*
head_size
*
sizeof
(
float
);
...
...
csrc/cuda_compat.h
View file @
90eeea8f
...
@@ -4,8 +4,35 @@
...
@@ -4,8 +4,35 @@
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
#endif
#endif
#if defined(USE_ROCM) && defined(__GFX9__)
#ifdef USE_ROCM
#define WARP_SIZE 64
struct
Utils
{
static
__host__
int
get_warp_size
()
{
static
bool
is_cached
=
false
;
static
int
result
;
if
(
!
is_cached
)
{
int
device_id
;
cudaDeviceProp
deviceProp
;
cudaGetDevice
(
&
device_id
);
cudaGetDeviceProperties
(
&
deviceProp
,
device_id
);
result
=
deviceProp
.
warpSize
;
is_cached
=
true
;
}
return
result
;
}
static
__device__
constexpr
int
get_warp_size
()
{
#ifdef __GFX9__
return
64
;
#else
return
32
;
#endif
}
};
#define WARP_SIZE Utils::get_warp_size()
#else
#else
#define WARP_SIZE 32
#define WARP_SIZE 32
#endif
#endif
...
...
csrc/moe/topk_softmax_kernels.cu
View file @
90eeea8f
...
@@ -190,8 +190,8 @@ __launch_bounds__(TPB) __global__ void moeTopK(
...
@@ -190,8 +190,8 @@ __launch_bounds__(TPB) __global__ void moeTopK(
2) This implementation assumes k is small, but will work for any k.
2) This implementation assumes k is small, but will work for any k.
*/
*/
template
<
int
VPT
,
int
NUM_EXPERTS
,
int
WARPS_PER_CTA
,
int
BYTES_PER_LDG
,
typename
IndType
>
template
<
int
VPT
,
int
NUM_EXPERTS
,
int
WARPS_PER_CTA
,
int
BYTES_PER_LDG
,
int
WARP_SIZE_PARAM
,
typename
IndType
>
__launch_bounds__
(
WARPS_PER_CTA
*
WARP_SIZE
)
__global__
__launch_bounds__
(
WARPS_PER_CTA
*
WARP_SIZE
_PARAM
)
__global__
void
topkGatingSoftmax
(
const
float
*
input
,
const
bool
*
finished
,
float
*
output
,
const
int
num_rows
,
IndType
*
indices
,
void
topkGatingSoftmax
(
const
float
*
input
,
const
bool
*
finished
,
float
*
output
,
const
int
num_rows
,
IndType
*
indices
,
int
*
source_rows
,
const
int
k
,
const
int
start_expert
,
const
int
end_expert
)
int
*
source_rows
,
const
int
k
,
const
int
start_expert
,
const
int
end_expert
)
{
{
...
@@ -209,12 +209,12 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
...
@@ -209,12 +209,12 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
// Restrictions based on previous section.
// Restrictions based on previous section.
static_assert
(
VPT
%
ELTS_PER_LDG
==
0
,
"The elements per thread must be a multiple of the elements per ldg"
);
static_assert
(
VPT
%
ELTS_PER_LDG
==
0
,
"The elements per thread must be a multiple of the elements per ldg"
);
static_assert
(
WARP_SIZE
%
THREADS_PER_ROW
==
0
,
"The threads per row must cleanly divide the threads per warp"
);
static_assert
(
WARP_SIZE
_PARAM
%
THREADS_PER_ROW
==
0
,
"The threads per row must cleanly divide the threads per warp"
);
static_assert
(
THREADS_PER_ROW
==
(
THREADS_PER_ROW
&
-
THREADS_PER_ROW
),
"THREADS_PER_ROW must be power of 2"
);
static_assert
(
THREADS_PER_ROW
==
(
THREADS_PER_ROW
&
-
THREADS_PER_ROW
),
"THREADS_PER_ROW must be power of 2"
);
static_assert
(
THREADS_PER_ROW
<=
WARP_SIZE
,
"THREADS_PER_ROW can be at most warp size"
);
static_assert
(
THREADS_PER_ROW
<=
WARP_SIZE
_PARAM
,
"THREADS_PER_ROW can be at most warp size"
);
// We have NUM_EXPERTS elements per row. We specialize for small #experts
// We have NUM_EXPERTS elements per row. We specialize for small #experts
static
constexpr
int
ELTS_PER_WARP
=
WARP_SIZE
*
VPT
;
static
constexpr
int
ELTS_PER_WARP
=
WARP_SIZE
_PARAM
*
VPT
;
static
constexpr
int
ROWS_PER_WARP
=
ELTS_PER_WARP
/
ELTS_PER_ROW
;
static
constexpr
int
ROWS_PER_WARP
=
ELTS_PER_WARP
/
ELTS_PER_ROW
;
static
constexpr
int
ROWS_PER_CTA
=
WARPS_PER_CTA
*
ROWS_PER_WARP
;
static
constexpr
int
ROWS_PER_CTA
=
WARPS_PER_CTA
*
ROWS_PER_WARP
;
...
@@ -393,41 +393,51 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
...
@@ -393,41 +393,51 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
namespace
detail
namespace
detail
{
{
// Constructs some constants needed to partition the work across threads at compile time.
// Constructs some constants needed to partition the work across threads at compile time.
template
<
int
EXPERTS
,
int
BYTES_PER_LDG
>
template
<
int
EXPERTS
,
int
BYTES_PER_LDG
,
int
WARP_SIZE_PARAM
>
struct
TopkConstants
struct
TopkConstants
{
{
static
constexpr
int
ELTS_PER_LDG
=
BYTES_PER_LDG
/
sizeof
(
float
);
static
constexpr
int
ELTS_PER_LDG
=
BYTES_PER_LDG
/
sizeof
(
float
);
static_assert
(
EXPERTS
/
(
ELTS_PER_LDG
*
WARP_SIZE
)
==
0
||
EXPERTS
%
(
ELTS_PER_LDG
*
WARP_SIZE
)
==
0
,
""
);
static_assert
(
EXPERTS
/
(
ELTS_PER_LDG
*
WARP_SIZE
_PARAM
)
==
0
||
EXPERTS
%
(
ELTS_PER_LDG
*
WARP_SIZE
_PARAM
)
==
0
,
""
);
static
constexpr
int
VECs_PER_THREAD
=
MAX
(
1
,
EXPERTS
/
(
ELTS_PER_LDG
*
WARP_SIZE
));
static
constexpr
int
VECs_PER_THREAD
=
MAX
(
1
,
EXPERTS
/
(
ELTS_PER_LDG
*
WARP_SIZE
_PARAM
));
static
constexpr
int
VPT
=
VECs_PER_THREAD
*
ELTS_PER_LDG
;
static
constexpr
int
VPT
=
VECs_PER_THREAD
*
ELTS_PER_LDG
;
static
constexpr
int
THREADS_PER_ROW
=
EXPERTS
/
VPT
;
static
constexpr
int
THREADS_PER_ROW
=
EXPERTS
/
VPT
;
static
const
expr
int
ROWS_PER_WARP
=
WARP_SIZE
/
THREADS_PER_ROW
;
static
const
int
ROWS_PER_WARP
=
WARP_SIZE
_PARAM
/
THREADS_PER_ROW
;
};
};
}
// namespace detail
}
// namespace detail
template
<
int
EXPERTS
,
int
WARPS_PER_TB
,
typename
IndType
>
template
<
int
EXPERTS
,
int
WARPS_PER_TB
,
int
WARP_SIZE_PARAM
,
typename
IndType
>
void
topkGatingSoftmaxLauncherHelper
(
const
float
*
input
,
const
bool
*
finished
,
float
*
output
,
IndType
*
indices
,
void
topkGatingSoftmaxLauncherHelper
(
const
float
*
input
,
const
bool
*
finished
,
float
*
output
,
IndType
*
indices
,
int
*
source_row
,
const
int
num_rows
,
const
int
k
,
const
int
start_expert
,
const
int
end_expert
,
cudaStream_t
stream
)
int
*
source_row
,
const
int
num_rows
,
const
int
k
,
const
int
start_expert
,
const
int
end_expert
,
cudaStream_t
stream
)
{
{
static
constexpr
std
::
size_t
MAX_BYTES_PER_LDG
=
16
;
static
constexpr
std
::
size_t
MAX_BYTES_PER_LDG
=
16
;
static
constexpr
int
BYTES_PER_LDG
=
MIN
(
MAX_BYTES_PER_LDG
,
sizeof
(
float
)
*
EXPERTS
);
static
constexpr
int
BYTES_PER_LDG
=
MIN
(
MAX_BYTES_PER_LDG
,
sizeof
(
float
)
*
EXPERTS
);
using
Constants
=
detail
::
TopkConstants
<
EXPERTS
,
BYTES_PER_LDG
>
;
using
Constants
=
detail
::
TopkConstants
<
EXPERTS
,
BYTES_PER_LDG
,
WARP_SIZE_PARAM
>
;
static
constexpr
int
VPT
=
Constants
::
VPT
;
static
constexpr
int
VPT
=
Constants
::
VPT
;
static
constexpr
int
ROWS_PER_WARP
=
Constants
::
ROWS_PER_WARP
;
static
constexpr
int
ROWS_PER_WARP
=
Constants
::
ROWS_PER_WARP
;
const
int
num_warps
=
(
num_rows
+
ROWS_PER_WARP
-
1
)
/
ROWS_PER_WARP
;
const
int
num_warps
=
(
num_rows
+
ROWS_PER_WARP
-
1
)
/
ROWS_PER_WARP
;
const
int
num_blocks
=
(
num_warps
+
WARPS_PER_TB
-
1
)
/
WARPS_PER_TB
;
const
int
num_blocks
=
(
num_warps
+
WARPS_PER_TB
-
1
)
/
WARPS_PER_TB
;
dim3
block_dim
(
WARP_SIZE
,
WARPS_PER_TB
);
dim3
block_dim
(
WARP_SIZE
_PARAM
,
WARPS_PER_TB
);
topkGatingSoftmax
<
VPT
,
EXPERTS
,
WARPS_PER_TB
,
BYTES_PER_LDG
><<<
num_blocks
,
block_dim
,
0
,
stream
>>>
(
topkGatingSoftmax
<
VPT
,
EXPERTS
,
WARPS_PER_TB
,
BYTES_PER_LDG
,
WARP_SIZE_PARAM
><<<
num_blocks
,
block_dim
,
0
,
stream
>>>
(
input
,
finished
,
output
,
num_rows
,
indices
,
source_row
,
k
,
start_expert
,
end_expert
);
input
,
finished
,
output
,
num_rows
,
indices
,
source_row
,
k
,
start_expert
,
end_expert
);
}
}
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB) \
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB) \
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB>( \
switch (warpSize) { \
case 32: \
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 32>( \
gating_output, nullptr, topk_weights, topk_indices, \
gating_output, nullptr, topk_weights, topk_indices, \
token_expert_indices, num_tokens, topk, 0, num_experts, \
token_expert_indices, num_tokens, topk, 0, num_experts, stream); \
stream);
break; \
case 64: \
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 64>( \
gating_output, nullptr, topk_weights, topk_indices, \
token_expert_indices, num_tokens, topk, 0, num_experts, stream); \
break; \
default: \
TORCH_CHECK(false, "Unsupported warp size: ", warpSize); \
}
template
<
typename
IndType
>
template
<
typename
IndType
>
void
topkGatingSoftmaxKernelLauncher
(
void
topkGatingSoftmaxKernelLauncher
(
...
@@ -441,6 +451,7 @@ void topkGatingSoftmaxKernelLauncher(
...
@@ -441,6 +451,7 @@ void topkGatingSoftmaxKernelLauncher(
const
int
topk
,
const
int
topk
,
cudaStream_t
stream
)
{
cudaStream_t
stream
)
{
static
constexpr
int
WARPS_PER_TB
=
4
;
static
constexpr
int
WARPS_PER_TB
=
4
;
auto
warpSize
=
WARP_SIZE
;
switch
(
num_experts
)
{
switch
(
num_experts
)
{
case
1
:
case
1
:
LAUNCH_SOFTMAX
(
1
,
WARPS_PER_TB
);
LAUNCH_SOFTMAX
(
1
,
WARPS_PER_TB
);
...
...
csrc/quantization/activation_kernels.cu
View file @
90eeea8f
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
#include <cmath>
#include <cmath>
#include "core/math.hpp"
#include "core/math.hpp"
#include "cuda_compat.h"
#include "
../
cuda_compat.h"
#include "dispatch_utils.h"
#include "dispatch_utils.h"
#include "quantization/fp8/common.cuh"
#include "quantization/fp8/common.cuh"
...
...
csrc/quantization/gguf/gguf_kernel.cu
View file @
90eeea8f
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
#include <torch/all.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAGuard.h>
#include "cuda_compat.h"
#include "
../../
cuda_compat.h"
#include "dispatch_utils.h"
#include "dispatch_utils.h"
#include "ggml-common.h"
#include "ggml-common.h"
...
...
csrc/rocm/attention.cu
View file @
90eeea8f
...
@@ -19,7 +19,7 @@
...
@@ -19,7 +19,7 @@
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAGuard.h>
#include <hip/hip_fp8.h>
#include <hip/hip_fp8.h>
#include <hip/hip_bf16.h>
#include <hip/hip_bf16.h>
#include "cuda_compat.h"
#include "
../
cuda_compat.h"
#include <algorithm>
#include <algorithm>
#include "../attention/dtype_fp8.cuh"
#include "../attention/dtype_fp8.cuh"
...
...
csrc/rocm/skinny_gemms.cu
View file @
90eeea8f
...
@@ -9,7 +9,7 @@
...
@@ -9,7 +9,7 @@
#include <stdexcept>
#include <stdexcept>
#include <algorithm>
#include <algorithm>
#include "cuda_compat.h"
#include "
../
cuda_compat.h"
#include "dispatch_utils.h"
#include "dispatch_utils.h"
#include "quantization/fp8/common.cuh"
#include "quantization/fp8/common.cuh"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment