Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0da93439
Commit
0da93439
authored
Mar 26, 2026
by
zhuwenwen
Browse files
Merge tag 'v0.18.1rc0' into v0.18.1rc0-ori
parents
25f2f756
298e5108
Changes
676
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1110 additions
and
137 deletions
+1110
-137
benchmarks/kernels/benchmark_router_gemm.py
benchmarks/kernels/benchmark_router_gemm.py
+134
-0
benchmarks/kernels/cpu/benchmark_cpu_attn.py
benchmarks/kernels/cpu/benchmark_cpu_attn.py
+1
-1
benchmarks/kernels/cpu/benchmark_cpu_fused_moe.py
benchmarks/kernels/cpu/benchmark_cpu_fused_moe.py
+1
-1
cmake/external_projects/vllm_flash_attn.cmake
cmake/external_projects/vllm_flash_attn.cmake
+1
-1
csrc/cpu/utils.cpp
csrc/cpu/utils.cpp
+4
-1
csrc/libtorch_stable/ops.h
csrc/libtorch_stable/ops.h
+9
-0
csrc/libtorch_stable/permute_cols.cu
csrc/libtorch_stable/permute_cols.cu
+23
-17
csrc/libtorch_stable/torch_bindings.cpp
csrc/libtorch_stable/torch_bindings.cpp
+21
-0
csrc/libtorch_stable/torch_utils.h
csrc/libtorch_stable/torch_utils.h
+13
-0
csrc/moe/gpt_oss_router_gemm.cu
csrc/moe/gpt_oss_router_gemm.cu
+144
-0
csrc/moe/gpt_oss_router_gemm.cuh
csrc/moe/gpt_oss_router_gemm.cuh
+447
-0
csrc/moe/moe_ops.h
csrc/moe/moe_ops.h
+4
-0
csrc/moe/torch_bindings.cpp
csrc/moe/torch_bindings.cpp
+6
-0
csrc/ops.h
csrc/ops.h
+2
-2
csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
.../fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
+9
-0
csrc/quantization/w8a8/cutlass/moe/moe_data.cu
csrc/quantization/w8a8/cutlass/moe/moe_data.cu
+15
-13
csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu
csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu
+5
-3
csrc/rocm/skinny_gemms.cu
csrc/rocm/skinny_gemms.cu
+268
-92
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+2
-5
docker/Dockerfile.rocm_base
docker/Dockerfile.rocm_base
+1
-1
No files found.
benchmarks/kernels/benchmark_router_gemm.py
0 → 100644
View file @
0da93439
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
import
torch.nn.functional
as
F
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.transformers_utils.config
import
get_config
from
vllm.triton_utils
import
triton
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
# Dimensions supported by the DSV3 specialized kernel
DSV3_SUPPORTED_NUM_EXPERTS
=
[
256
,
384
]
DSV3_SUPPORTED_HIDDEN_SIZES
=
[
7168
]
# Dimensions supported by the gpt-oss specialized kernel
GPT_OSS_SUPPORTED_NUM_EXPERTS
=
[
32
,
128
]
GPT_OSS_SUPPORTED_HIDDEN_SIZES
=
[
2880
]
def get_batch_size_range(max_batch_size):
    """Return the benchmarked batch sizes, in increasing order.

    Candidates are the powers of two 2**0 .. 2**13; only those that do not
    exceed ``max_batch_size`` are kept.
    """
    sizes = []
    for exponent in range(14):
        candidate = 1 << exponent
        if candidate <= max_batch_size:
            sizes.append(candidate)
    return sizes
def get_model_params(config):
    """Extract ``(num_experts, hidden_size)`` from a model config.

    The first entry of ``config.architectures`` selects which attribute holds
    the expert count: ``n_routed_experts`` for the DeepSeek architectures and
    ``num_local_experts`` for gpt-oss.

    Raises:
        ValueError: If the architecture is not one of the handled ones.
    """
    arch = config.architectures[0]
    if arch in (
        "DeepseekV2ForCausalLM",
        "DeepseekV3ForCausalLM",
        "DeepseekV32ForCausalLM",
    ):
        expert_attr = "n_routed_experts"
    elif arch in ("GptOssForCausalLM",):
        expert_attr = "num_local_experts"
    else:
        raise ValueError(f"Unsupported architecture: {config.architectures}")

    # Expert count is read before hidden_size, as in the original branches.
    num_experts = getattr(config, expert_attr)
    return num_experts, config.hidden_size
def get_benchmark(model, max_batch_size, trust_remote_code):
    """Build the Triton perf-report benchmark for a model's router GEMM.

    The returned benchmark compares ``F.linear`` (provider ``"torch"``) with
    vLLM's specialized router GEMM ops (provider ``"vllm"``) over the batch
    sizes from ``get_batch_size_range`` and reports throughput in TFLOPs.

    Args:
        model: Model name or path passed to ``get_config``; also used in the
            plot title.
        max_batch_size: Upper bound for the benchmarked batch sizes.
        trust_remote_code: Forwarded to ``get_config``.

    Returns:
        The decorated benchmark object; call ``.run(...)`` on it to execute.

    Raises:
        ValueError: If the model architecture is unsupported (raised by
            ``get_model_params`` while the benchmark is being built).
    """
    # Everything here is invariant across (batch_size, provider) points, so it
    # is resolved once instead of re-loading the model config for every point.
    config = get_config(model=model, trust_remote_code=trust_remote_code)
    num_experts, hidden_size = get_model_params(config)

    is_hopper_or_blackwell = current_platform.is_device_capability(
        90
    ) or current_platform.is_device_capability_family(100)
    allow_dsv3_router_gemm = (
        is_hopper_or_blackwell
        and num_experts in DSV3_SUPPORTED_NUM_EXPERTS
        and hidden_size in DSV3_SUPPORTED_HIDDEN_SIZES
    )
    allow_gpt_oss_router_gemm = (
        is_hopper_or_blackwell
        and num_experts in GPT_OSS_SUPPORTED_NUM_EXPERTS
        and hidden_size in GPT_OSS_SUPPORTED_HIDDEN_SIZES
    )
    # Only the gpt-oss kernel takes a bias; the torch baseline adds one in
    # exactly that case so both providers do the same work.
    has_bias = allow_gpt_oss_router_gemm

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["batch_size"],
            x_vals=get_batch_size_range(max_batch_size),
            x_log=False,
            line_arg="provider",
            line_vals=[
                "torch",
                "vllm",
            ],
            line_names=["PyTorch", "vLLM"],
            styles=([("blue", "-"), ("red", "-")]),
            ylabel="TFLOPs",
            plot_name=f"{model} router gemm throughput",
            args={},
        )
    )
    def benchmark(batch_size, provider):
        """Time one (batch_size, provider) point; return (median, low, high) TFLOPs."""
        mat_a = torch.randn(
            (batch_size, hidden_size), dtype=torch.bfloat16, device="cuda"
        ).contiguous()
        mat_b = torch.randn(
            (num_experts, hidden_size), dtype=torch.bfloat16, device="cuda"
        ).contiguous()
        bias = torch.randn(
            num_experts, dtype=torch.bfloat16, device="cuda"
        ).contiguous()

        if provider == "torch":
            linear_args = (mat_a, mat_b, bias) if has_bias else (mat_a, mat_b)

            def runner():
                F.linear(*linear_args)

        elif provider == "vllm":
            # Reject unsupported configurations before handing the runner to
            # the CUDA-graph benchmark helper rather than from inside it.
            if allow_dsv3_router_gemm:

                def runner():
                    ops.dsv3_router_gemm(mat_a, mat_b, torch.bfloat16)

            elif allow_gpt_oss_router_gemm:

                def runner():
                    ops.gpt_oss_router_gemm(mat_a, mat_b, bias)

            else:
                raise ValueError("Unsupported router gemm")
        else:
            # Previously this fell through and failed on an unbound `runner`.
            raise ValueError(f"Unknown provider: {provider}")

        # Quantile order is (median, 20th, 80th percentile) of the latency.
        quantiles = [0.5, 0.2, 0.8]
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            runner, quantiles=quantiles
        )

        def tflops(t_ms):
            flops = 2 * batch_size * hidden_size * num_experts
            return flops / (t_ms * 1e-3) / 1e12

        # Throughput is inverse to latency, so the slow quantile gives the low
        # bound and the fast quantile the high bound.
        return tflops(ms), tflops(max_ms), tflops(min_ms)

    return benchmark
if __name__ == "__main__":
    # Command-line entry point: pick the model and the largest batch size.
    cli = FlexibleArgumentParser()
    cli.add_argument("--model", type=str, default="openai/gpt-oss-20b")
    cli.add_argument("--max-batch-size", default=16, type=int)
    cli.add_argument("--trust-remote-code", action="store_true")
    options = cli.parse_args()

    # Build the perf-report benchmark for the requested model, then run it and
    # print the resulting table.
    router_gemm_benchmark = get_benchmark(
        options.model, options.max_batch_size, options.trust_remote_code
    )
    router_gemm_benchmark.run(print_data=True)
benchmarks/kernels/cpu/benchmark_cpu_attn.py
View file @
0da93439
...
...
@@ -27,7 +27,7 @@ def get_attn_isa(
else
:
if
current_platform
.
get_cpu_architecture
()
==
CpuArchEnum
.
ARM
:
return
"neon"
elif
torch
.
_C
.
_
cpu
.
_is_amx_tile_supported
():
elif
torch
.
cpu
.
_is_amx_tile_supported
():
return
"amx"
else
:
return
"vec"
...
...
benchmarks/kernels/cpu/benchmark_cpu_fused_moe.py
View file @
0da93439
...
...
@@ -24,7 +24,7 @@ except (ImportError, AttributeError) as e:
sys
.
exit
(
1
)
# ISA selection following test_cpu_fused_moe.py pattern
ISA_CHOICES
=
[
"amx"
,
"vec"
]
if
torch
.
_C
.
_
cpu
.
_is_amx_tile_supported
()
else
[
"vec"
]
ISA_CHOICES
=
[
"amx"
,
"vec"
]
if
torch
.
cpu
.
_is_amx_tile_supported
()
else
[
"vec"
]
@
torch
.
inference_mode
()
...
...
cmake/external_projects/vllm_flash_attn.cmake
View file @
0da93439
...
...
@@ -39,7 +39,7 @@ else()
FetchContent_Declare
(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG
1488682bb545f7d020e958a33116b1419d1cfc83
GIT_TAG
29210221863736a08f71a866459e368ad1ac4a95
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR
${
CMAKE_BINARY_DIR
}
/vllm-flash-attn
...
...
csrc/cpu/utils.cpp
View file @
0da93439
...
...
@@ -173,10 +173,13 @@ ScratchPadManager::ScratchPadManager() : size_(0), ptr_(nullptr) {
// Grow the scratch pad to hold at least `new_size` bytes.
//
// The request is first passed through round() (presumably rounding up to the
// allocation granularity -- confirm against its definition). Nothing happens
// when the rounded size already fits. The old contents are NOT preserved: the
// previous buffer is freed, not copied.
void ScratchPadManager::realloc(size_t new_size) {
  new_size = round(new_size);
  if (new_size > size_) {
    // Allocate and verify the new buffer before releasing the old one, so a
    // failed allocation leaves ptr_/size_ untouched.
    void* new_ptr = std::aligned_alloc(64, new_size);
    TORCH_CHECK(new_ptr != nullptr,
                "ScratchPadManager: aligned_alloc failed for size ", new_size);
    if (ptr_ != nullptr) {
      std::free(ptr_);
    }
    // Adopt the checked allocation. A second aligned_alloc assigned to ptr_
    // here was immediately overwritten and therefore leaked.
    ptr_ = new_ptr;
    size_ = new_size;
  }
}
...
...
csrc/libtorch_stable/ops.h
0 → 100644
View file @
0da93439
#pragma once
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#ifndef USE_ROCM
// Column permutation op on stable-ABI tensors; declared only when USE_ROCM is
// not defined. Takes the tensor `A` and the permutation tensor `perm` and
// returns a new tensor.
torch::stable::Tensor permute_cols(torch::stable::Tensor const& A,
                                   torch::stable::Tensor const& perm);
#endif
csrc/permute_cols.cu
→
csrc/
libtorch_stable/
permute_cols.cu
View file @
0da93439
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/accelerator.h>
#include <torch/csrc/stable/ops.h>
#include <torch/headeronly/core/ScalarType.h>
#include <cuda_fp16.h>
#include "torch_utils.h"
static
constexpr
int
default_threads
=
256
;
static
constexpr
int
div_ceil
(
int
a
,
int
b
)
{
return
(
a
+
b
-
1
)
/
b
;
}
...
...
@@ -64,19 +67,22 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
// More efficient version of A[..., perm]
// taken from gptq_marlin.cu
torch
::
Tensor
permute_cols
(
torch
::
Tensor
const
&
A
,
torch
::
Tensor
const
&
perm
)
{
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
A
));
auto
dev
=
A
.
get_device
();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
dev
);
TORCH_CHECK
(
A
.
scalar_type
()
==
at
::
kHalf
||
A
.
scalar_type
()
==
at
::
kBFloat16
,
"Currently only 16bit types are supported"
);
TORCH_CHECK
(
A
.
is_contiguous
(),
"A must be contiguous"
);
TORCH_CHECK
(
A
.
size
(
-
1
)
%
8
==
0
,
"A columns must be a multiple of 8 (128bits)"
);
auto
A_2d
=
A
.
view
({
-
1
,
A
.
size
(
-
1
)});
torch
::
Tensor
D
=
torch
::
empty_like
(
A
);
torch
::
stable
::
Tensor
permute_cols
(
torch
::
stable
::
Tensor
const
&
A
,
torch
::
stable
::
Tensor
const
&
perm
)
{
const
int32_t
dev
=
A
.
get_device_index
();
const
torch
::
stable
::
accelerator
::
DeviceGuard
device_guard
(
dev
);
const
auto
stream
=
get_current_cuda_stream
(
dev
);
STD_TORCH_CHECK
(
A
.
scalar_type
()
==
torch
::
headeronly
::
ScalarType
::
Half
||
A
.
scalar_type
()
==
torch
::
headeronly
::
ScalarType
::
BFloat16
,
"Currently only 16bit types are supported"
);
STD_TORCH_CHECK
(
A
.
is_contiguous
(),
"A must be contiguous"
);
STD_TORCH_CHECK
(
A
.
size
(
-
1
)
%
8
==
0
,
"A columns must be a multiple of 8 (128bits)"
);
auto
A_2d
=
torch
::
stable
::
view
(
A
,
{
-
1
,
A
.
size
(
-
1
)});
torch
::
stable
::
Tensor
D
=
torch
::
stable
::
empty_like
(
A
);
int
sms
;
cudaDeviceGetAttribute
(
&
sms
,
cudaDevAttrMultiProcessorCount
,
dev
);
int
block_rows
=
div_ceil
(
A_2d
.
size
(
0
),
sms
);
...
...
csrc/libtorch_stable/torch_bindings.cpp
0 → 100644
View file @
0da93439
#include "ops.h"
#include "core/registration.h"
#include <torch/csrc/stable/library.h>
// Register ops with STABLE_TORCH_LIBRARY for libtorch stable ABI compatibility.
// Note: We register under namespace "_C" so ops are accessible as
// torch.ops._C.<op_name> for compatibility with existing code.
// Schema declarations for the `_C` namespace.
STABLE_TORCH_LIBRARY_FRAGMENT(_C, m) {
#ifndef USE_ROCM
  // Declared only for non-ROCm builds, matching the guard in ops.h.
  m.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
#endif
}

// CUDA implementations for the schemas declared above.
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
#ifndef USE_ROCM
  // TORCH_BOX wraps the typed function for registration through the stable API.
  m.impl("permute_cols", TORCH_BOX(&permute_cols));
#endif
}

// NOTE(review): REGISTER_EXTENSION comes from core/registration.h (not shown
// here); presumably it defines the Python module init for this shared
// library -- confirm against that header.
REGISTER_EXTENSION(_C_stable_libtorch)
csrc/libtorch_stable/torch_utils.h
0 → 100644
View file @
0da93439
#pragma once
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <cuda_runtime.h>
// Utility to get the current CUDA stream for a given device using stable APIs.
// Returns a cudaStream_t for use in kernel launches.
inline cudaStream_t get_current_cuda_stream(int32_t device_index) {
  // The shim returns the stream through an opaque void* out-parameter;
  // TORCH_ERROR_CODE_CHECK inspects the returned error code.
  void* raw_stream = nullptr;
  TORCH_ERROR_CODE_CHECK(
      aoti_torch_get_current_cuda_stream(device_index, &raw_stream));
  cudaStream_t const stream = reinterpret_cast<cudaStream_t>(raw_stream);
  return stream;
}
csrc/moe/gpt_oss_router_gemm.cu
0 → 100644
View file @
0da93439
/*
* Adapted from
* https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc7/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_cuda.cu
* Copyright (c) 2025, The vLLM team.
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
* All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/all.h>
#include "gpt_oss_router_gemm.cuh"
// Host-side launcher for gpt_oss_router_gemm_kernel.
//
//   gA    activations, row-major [batch_size, input_features]
//   gB    weights,     row-major [output_features, input_features]
//   gC    output; the kernel stores element (batch n, feature m) at
//         gC[n * output_features + m]
//   bias  one value per output feature
//
// Builds one TMA tensor map per operand, raises the kernel's dynamic
// shared-memory limit, and launches on `stream` with programmatic stream
// serialization enabled.
//
// Throws (via TORCH_CHECK / gpuErrChk) if a tensor map cannot be encoded, the
// shared-memory attribute cannot be set, or the launch itself is rejected.
void launch_gpt_oss_router_gemm(__nv_bfloat16* gA, __nv_bfloat16* gB,
                                __nv_bfloat16* gC, __nv_bfloat16* bias,
                                int batch_size, int output_features,
                                int input_features, cudaStream_t stream) {
  static int const WARP_TILE_M = 16;
  static int const TILE_M = WARP_TILE_M;
  static int const TILE_N = 8;
  static int const TILE_K = 64;
  static int const STAGES = 16;
  static int const STAGE_UNROLL = 4;
  static bool const PROFILE = false;

  CUtensorMap weight_map{};
  CUtensorMap activation_map{};

  // Both operands are described as 2D tensors whose innermost dimension is
  // input_features; only the outer extent and box height differ.
  constexpr uint32_t rank = 2;
  uint64_t size[rank] = {(uint64_t)input_features, (uint64_t)output_features};
  uint64_t stride[rank - 1] = {input_features * sizeof(__nv_bfloat16)};
  uint32_t box_size[rank] = {TILE_K, TILE_M};
  uint32_t elem_stride[rank] = {1, 1};

  CUresult res = cuTensorMapEncodeTiled(
      &weight_map, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, rank,
      gB, size, stride, box_size, elem_stride,
      CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
      CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
      CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
      CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
  TORCH_CHECK(res == CUDA_SUCCESS,
              "cuTensorMapEncodeTiled failed for weight_map, error code=",
              static_cast<int>(res));

  // Reuse the descriptor arrays for the activations: outer extent becomes the
  // batch size and the box height becomes TILE_N.
  size[1] = batch_size;
  box_size[1] = TILE_N;

  res = cuTensorMapEncodeTiled(
      &activation_map, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
      rank, gA, size, stride, box_size, elem_stride,
      CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
      CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
      CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
      CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
  TORCH_CHECK(res == CUDA_SUCCESS,
              "cuTensorMapEncodeTiled failed for activation_map, error code=",
              static_cast<int>(res));

  // One weight tile plus one activation tile per (stage, unroll) slot.
  int smem_size = STAGES * STAGE_UNROLL *
                  (TILE_M * TILE_K * sizeof(__nv_bfloat16) +
                   TILE_N * TILE_K * sizeof(__nv_bfloat16));

  gpuErrChk(cudaFuncSetAttribute(
      gpt_oss_router_gemm_kernel<WARP_TILE_M, TILE_M, TILE_N, TILE_K, STAGES,
                                 STAGE_UNROLL, PROFILE>,
      cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));

  int tiles_m = (output_features + TILE_M - 1) / TILE_M;
  int tiles_n = (batch_size + TILE_N - 1) / TILE_N;

  dim3 grid(tiles_m, tiles_n);
  dim3 block(384);

  // Value-initialize so no launch-config field is left indeterminate.
  cudaLaunchConfig_t config{};
  cudaLaunchAttribute attrs[1];
  config.gridDim = grid;
  config.blockDim = block;
  config.dynamicSmemBytes = smem_size;
  config.stream = stream;
  config.attrs = attrs;
  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attrs[0].val.programmaticStreamSerializationAllowed = 1;
  config.numAttrs = 1;

  // NOTE(review): gA/gB land in the kernel parameters named `weights` /
  // `activations` respectively. The kernel in gpt_oss_router_gemm.cuh never
  // dereferences those two pointers (all loads go through the tensor maps),
  // so the swapped names are harmless today.
  //
  // The launch status is now checked; previously a rejected launch (e.g. an
  // invalid configuration) was silently ignored.
  gpuErrChk(cudaLaunchKernelEx(
      &config,
      &gpt_oss_router_gemm_kernel<WARP_TILE_M, TILE_M, TILE_N, TILE_K, STAGES,
                                  STAGE_UNROLL, PROFILE>,
      gC, gA, gB, bias, output_features, batch_size, input_features,
      weight_map, activation_map, nullptr));
}
// Unpacks tensor sizes and raw pointers and forwards them to
// launch_gpt_oss_router_gemm on the current CUDA stream. Only bfloat16 input
// is accepted; any other dtype throws std::invalid_argument.
void gpt_oss_router_gemm_cuda_forward(torch::Tensor& output,
                                      torch::Tensor input,
                                      torch::Tensor weight,
                                      torch::Tensor bias) {
  // Sizes are read before the dtype test, as in the original ordering.
  auto const num_tokens = input.size(0);
  auto const in_features = input.size(1);
  auto const out_features = weight.size(0);

  auto stream = at::cuda::getCurrentCUDAStream();

  if (!(input.scalar_type() == at::ScalarType::BFloat16)) {
    throw std::invalid_argument("Unsupported dtype, only supports bfloat16");
  }

  auto* input_ptr = (__nv_bfloat16*)input.data_ptr();
  auto* weight_ptr = (__nv_bfloat16*)weight.data_ptr();
  auto* output_ptr = (__nv_bfloat16*)output.mutable_data_ptr();
  auto* bias_ptr = (__nv_bfloat16*)bias.data_ptr();

  // The launcher takes int extents; the 64-bit sizes are narrowed here.
  launch_gpt_oss_router_gemm(input_ptr, weight_ptr, output_ptr, bias_ptr,
                             num_tokens, out_features, in_features, stream);
}
// Validated entry point for the gpt-oss router GEMM:
//   output[batch, experts] = input[batch, hidden] x weight[experts, hidden]^T
//                            + bias[experts]
//
// All four tensors must be bfloat16 and contiguous. The kernel addresses
// `output` as a dense row-major [input.size(0), weight.size(0)] buffer and the
// TMA descriptors derive their row stride from input.size(1), so a
// mis-shaped or strided tensor would otherwise be read or written incorrectly
// without any error.
void gpt_oss_router_gemm(torch::Tensor& output, torch::Tensor input,
                         torch::Tensor weight, torch::Tensor bias) {
  TORCH_CHECK(input.dim() == 2, "input must be 2D");
  TORCH_CHECK(weight.dim() == 2, "weight must be 2D");
  TORCH_CHECK(bias.dim() == 1, "bias must be 1D");
  TORCH_CHECK(input.sizes()[1] == weight.sizes()[1],
              "input.size(1) must match weight.size(1)");
  TORCH_CHECK(weight.sizes()[0] == bias.sizes()[0],
              "weight.size(0) must match bias.size(0)");
  TORCH_CHECK(input.scalar_type() == at::ScalarType::BFloat16,
              "input tensor must be bfloat16");
  TORCH_CHECK(weight.scalar_type() == at::ScalarType::BFloat16,
              "weight tensor must be bfloat16");
  TORCH_CHECK(bias.scalar_type() == at::ScalarType::BFloat16,
              "bias tensor must be bfloat16");

  // The destination buffer was previously unchecked.
  TORCH_CHECK(output.dim() == 2, "output must be 2D");
  TORCH_CHECK(output.sizes()[0] == input.sizes()[0],
              "output.size(0) must match input.size(0)");
  TORCH_CHECK(output.sizes()[1] == weight.sizes()[0],
              "output.size(1) must match weight.size(0)");
  TORCH_CHECK(output.scalar_type() == at::ScalarType::BFloat16,
              "output tensor must be bfloat16");

  // Raw pointers are handed to the kernel, so every buffer must be dense.
  TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
  TORCH_CHECK(weight.is_contiguous(), "weight must be contiguous");
  TORCH_CHECK(bias.is_contiguous(), "bias must be contiguous");
  TORCH_CHECK(output.is_contiguous(), "output must be contiguous");

  gpt_oss_router_gemm_cuda_forward(output, input, weight, bias);
}
csrc/moe/gpt_oss_router_gemm.cuh
0 → 100644
View file @
0da93439
/*
* Adapted from
* https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc7/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_kernel.cuh
* Copyright (c) 2025, The vLLM team.
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
* All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cuda_bf16.h"
#include <stdint.h>
#include <stdio.h>
#include <vector>
#include "cuda_pipeline.h"
#include <cuda.h>
#include <cuda/barrier>
#include <cuda/std/utility>
#include <cuda_runtime.h>
using
barrier
=
cuda
::
barrier
<
cuda
::
thread_scope_block
>
;
namespace
cde
=
cuda
::
device
::
experimental
;
namespace
ptx
=
cuda
::
ptx
;
#define gpuErrChk(ans)                    \
  {                                       \
    gpuAssert((ans), __FILE__, __LINE__); \
  }

// Reports a failed CUDA runtime call: prints the error string with the call
// site to stderr and, when `abort` is true, throws std::runtime_error carrying
// the same error string. A successful status is a no-op.
inline void gpuAssert(cudaError_t code, char const* file, int line,
                      bool abort = true) {
  if (code == cudaSuccess) {
    return;
  }
  char const* const message = cudaGetErrorString(code);
  fprintf(stderr, "GPUassert: %s %s %d\n", message, file, line);
  if (abort) {
    throw std::runtime_error(message);
  }
}
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
// Returns the current value of the PTX %globaltimer special register as a
// 64-bit integer. Used for the optional Profile timestamps.
__device__ uint64_t gclock64() {
  unsigned long long int rv;
  asm volatile("mov.u64 %0, %%globaltimer;" : "=l"(rv));
  return rv;
}
// Issues a warp-wide `ldmatrix ... x1.m8n8` from shared memory and stores the
// single 32-bit result register into rv[0..1] (two packed bf16 values).
// `smem_ptr` is a 32-bit shared-memory address, e.g. the result of
// __cvta_generic_to_shared.
__device__ void ldmatrix(__nv_bfloat16 rv[2], uint32_t smem_ptr) {
  int dst;
  asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n"
               : "=r"(dst)
               : "r"(smem_ptr));
  // Reinterpret the bf16 pair as one 32-bit word to store the register as-is.
  int* rvi = reinterpret_cast<int*>(&rv[0]);
  rvi[0] = dst;
}
// Issues a warp-wide `ldmatrix ... x2.m8n8` from shared memory and stores the
// two 32-bit result registers into rv[0..3] (four packed bf16 values).
// `smem_ptr` is a 32-bit shared-memory address.
__device__ void ldmatrix2(__nv_bfloat16 rv[4], uint32_t smem_ptr) {
  int x, y;
  asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
               : "=r"(x), "=r"(y)
               : "r"(smem_ptr));

  // Each register holds two bf16 values; write them back as 32-bit words.
  int* rvi = reinterpret_cast<int*>(&rv[0]);
  rvi[0] = x;
  rvi[1] = y;
}
// Issues a warp-wide `ldmatrix ... x4.m8n8` from shared memory and stores the
// four 32-bit result registers into rv[0..7] (eight packed bf16 values).
// `smem_ptr` is a 32-bit shared-memory address.
__device__ void ldmatrix4(__nv_bfloat16 rv[8], uint32_t smem_ptr) {
  int x, y, z, w;
  asm volatile(
      "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];"
      : "=r"(x), "=r"(y), "=r"(z), "=r"(w)
      : "r"(smem_ptr));
  // Each register holds two bf16 values; write them back as 32-bit words.
  int* rvi = reinterpret_cast<int*>(&rv[0]);
  rvi[0] = x;
  rvi[1] = y;
  rvi[2] = z;
  rvi[3] = w;
}
// Wrapper around the warp-level `mma.sync.aligned.m16n8k8` instruction with
// bf16 inputs and f32 accumulators. `a` (4 bf16 = 2 registers) and `b`
// (2 bf16 = 1 register) are the per-thread operand fragments, `c` is the
// accumulator input and `d` receives the four float results; `d` and `c` may
// be the same array.
__device__ void HMMA_1688(float d[4], __nv_bfloat16 a[4], __nv_bfloat16 b[2],
                          float c[4]) {
  // View the packed bf16 fragments as the 32-bit registers the PTX expects.
  uint32_t const* A = reinterpret_cast<uint32_t const*>(&a[0]);
  uint32_t const* B = reinterpret_cast<uint32_t const*>(&b[0]);
  float const* C = reinterpret_cast<float const*>(&c[0]);
  float* D = reinterpret_cast<float*>(&d[0]);
  asm volatile(
      "mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 "
      "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
      : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
      : "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(C[0]), "f"(C[1]), "f"(C[2]),
        "f"(C[3]));
}
// Wrapper around the warp-level `mma.sync.aligned.m16n8k16` instruction with
// bf16 inputs and f32 accumulators. `a` (8 bf16 = 4 registers) and `b`
// (4 bf16 = 2 registers) are the per-thread operand fragments, `c` is the
// accumulator input and `d` receives the four float results; the kernel calls
// it with `d` and `c` aliasing the same accumulator array.
__device__ void HMMA_16816(float d[4], __nv_bfloat16 a[8], __nv_bfloat16 b[4],
                           float c[4]) {
  // View the packed bf16 fragments as the 32-bit registers the PTX expects.
  uint32_t const* A = reinterpret_cast<uint32_t const*>(&a[0]);
  uint32_t const* B = reinterpret_cast<uint32_t const*>(&b[0]);
  float const* C = reinterpret_cast<float const*>(&c[0]);
  float* D = reinterpret_cast<float*>(&d[0]);
  asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
      "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
      : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
      : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
        "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
}
// Blocks until the shared-memory mbarrier at `bar_ptr` reports completion for
// the given phase parity: loops on `mbarrier.try_wait.parity` until its
// predicate becomes true. `bar_ptr` is a 32-bit shared-memory address.
__device__ void bar_wait(uint32_t bar_ptr, int phase) {
  asm volatile(
      "{\n"
      ".reg .pred P1;\n"
      "LAB_WAIT:\n"
      "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n"
      "@P1 bra.uni DONE;\n"
      "bra.uni LAB_WAIT;\n"
      "DONE:\n"
      "}\n" ::"r"(bar_ptr),
      "r"(phase));
}
// Single attempt version of bar_wait: issues one `mbarrier.try_wait.parity` on
// the mbarrier at `bar_ptr` for the given phase parity and returns whether the
// predicate was set. Callers poll this in a loop.
__device__ bool bar_try_wait(uint32_t bar_ptr, int phase) {
  uint32_t success;
#ifdef INTERNAL
  // Only emitted when INTERNAL is defined; passes a compiler knob via .pragma.
  asm volatile(".pragma \"set knob DontInsertYield\";\n" : : : "memory");
#endif
  asm volatile(
      "{\n\t"
      ".reg .pred P1;\n\t"
      "mbarrier.try_wait.parity.shared::cta.b64 P1, [%1], %2;\n\t"
      // Materialize the predicate as 1/0 in the output register.
      "selp.b32 %0, 1, 0, P1;\n\t"
      "}"
      : "=r"(success)
      : "r"(bar_ptr), "r"(phase));
  return success;
}
// Elects one lane via `elect.sync` with a full 0xFFFFFFFF member mask. Returns
// 1 in the lane whose election predicate is set and 0 in every other lane.
// The elected lane id is also read into `laneid`, but it is not returned.
__device__ uint32_t elect_one_sync() {
  uint32_t pred = 0;
  uint32_t laneid = 0;
  asm volatile(
      "{\n"
      ".reg .b32 %%rx;\n"
      ".reg .pred %%px;\n"
      " elect.sync %%rx|%%px, %2;\n"
      // pred stays 0 unless this lane's predicate is set.
      "@%%px mov.s32 %1, 1;\n"
      " mov.s32 %0, %%rx;\n"
      "}\n"
      : "+r"(laneid), "+r"(pred)
      : "r"(0xFFFFFFFF));
  return pred;
}
#endif
// Timestamps recorded by gpt_oss_router_gemm_kernel when its PROFILE template
// flag is set. The kernel indexes the array by blockIdx.x and only blocks with
// blockIdx.y == 0 write to it; every value is a gclock64() reading.
struct Profile {
  uint64_t start;              // kernel entry, written by thread 0
  uint64_t weight_load_start;  // first K iteration of a weight-loading warp
  uint64_t act_load_start;     // first K iteration of an activation-loading warp
  uint64_t compute_start;      // thread 0, once its first stage is ready
  uint64_t complete;           // thread 0, after the output stores
};
// Router GEMM with bias for SM90+:
//   output[n * M + m] = sum_k W[m, k] * X[n, k] + bias[m]
// where M = output features, N = batch rows and K = input features.
//
// Launch shape: 384 threads (12 warps) per block; blockIdx.x selects a TILE_M
// slab of output features, blockIdx.y a TILE_N slab of batch rows.
//   warps 0-3   compute warps; each consumes the pipeline stages congruent to
//               its warp id modulo 4 (a disjoint part of K) and accumulates a
//               16x8 tile
//   warps 4-7   one elected lane each issues the TMA weight loads
//   warps 8-11  one elected lane each issues the TMA activation loads
// After the K loop the four partial tiles are summed through shared memory by
// warp 0, which adds the bias and stores bf16 results.
//
// `weights` and `activations` are not dereferenced; all operand traffic goes
// through `weight_map` / `activation_map`. `profile` is only touched when
// PROFILE is true. The body compiles to nothing below SM90.
template <int WARP_TILE_M, int TILE_M, int TILE_N, int TILE_K, int STAGES,
          int STAGE_UNROLL, bool PROFILE>
__global__ __launch_bounds__(384, 1) void gpt_oss_router_gemm_kernel(
    __nv_bfloat16* output, __nv_bfloat16* weights, __nv_bfloat16* activations,
    __nv_bfloat16* bias, int M, int N, int K,
    const __grid_constant__ CUtensorMap weight_map,
    const __grid_constant__ CUtensorMap activation_map,
    Profile* profile = nullptr) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)

  if (PROFILE && threadIdx.x == 0 && blockIdx.y == 0)
    profile[blockIdx.x].start = gclock64();

  // Dynamic shared memory: all weight tiles first, then all activation tiles.
  extern __shared__ __align__(128) char smem[];

  __nv_bfloat16* sh_weights = (__nv_bfloat16*)&smem[0];
  __nv_bfloat16* sh_activations =
      (__nv_bfloat16*)&smem[STAGES * STAGE_UNROLL * TILE_M * TILE_K *
                            sizeof(__nv_bfloat16)];

#pragma nv_diag_suppress static_var_with_dynamic_init
  __shared__ barrier bar_wt_ready[STAGES];
  __shared__ barrier bar_act_ready[STAGES];
  __shared__ barrier bar_data_consumed[STAGES];

  __shared__ float4 reduction_buffer[128];

  __shared__ nv_bfloat16 sh_bias[TILE_M];

  if (threadIdx.x == 0) {
    for (int i = 0; i < STAGES; i++) {
      // "ready" barriers get a single arrival from the loading lane;
      // "consumed" barriers get one arrival per lane of a compute warp.
      init(&bar_wt_ready[i], 1);
      init(&bar_act_ready[i], 1);
      init(&bar_data_consumed[i], 32);
    }
    ptx::fence_proxy_async(ptx::space_shared);
    asm volatile("prefetch.tensormap [%0];"
                 :
                 : "l"(reinterpret_cast<uint64_t>(&weight_map))
                 : "memory");
    asm volatile("prefetch.tensormap [%0];"
                 :
                 : "l"(reinterpret_cast<uint64_t>(&activation_map))
                 : "memory");
  }
  __syncthreads();

  int warp_id = threadIdx.x / 32;
  int lane_id = threadIdx.x % 32;

  int phase = 0;

  int mib = blockIdx.x * TILE_M;
  int ni = blockIdx.y * TILE_N;

  float accum[4];
  for (int i = 0; i < 4; i++) accum[i] = 0.f;

  // Each iteration covers 4 stages (one per warp) of STAGE_UNROLL K-tiles.
  int const K_LOOPS_DMA =
      (K + 4 * TILE_K * STAGE_UNROLL - 1) / (4 * (TILE_K * STAGE_UNROLL));
  int const K_LOOPS_COMPUTE = K_LOOPS_DMA;

  // Data loading thread
  if (warp_id >= 4 && elect_one_sync()) {
    int stage = warp_id % 4;

    bool weight_warp = warp_id < 8;
    if (!weight_warp) {
      cudaGridDependencySynchronize();
      cudaTriggerProgrammaticLaunchCompletion();
    }

    for (int ki = 0; ki < K_LOOPS_DMA; ki++) {
      int k = (ki * 4 + (warp_id % 4)) * TILE_K * STAGE_UNROLL;

      uint64_t desc_ptr_wt = reinterpret_cast<uint64_t>(&weight_map);
      uint64_t desc_ptr_act = reinterpret_cast<uint64_t>(&activation_map);

      uint32_t bar_ptr_wt = __cvta_generic_to_shared(&bar_wt_ready[stage]);
      uint32_t bar_ptr_act = __cvta_generic_to_shared(&bar_act_ready[stage]);
      int bytes_wt = TILE_M * TILE_K * sizeof(__nv_bfloat16);
      int bytes_act = TILE_N * TILE_K * sizeof(__nv_bfloat16);

      // Do not overwrite a stage until its previous contents were consumed.
      bar_wait(__cvta_generic_to_shared(&bar_data_consumed[stage]), phase ^ 1);

      if (weight_warp)
        asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"
                     :
                     : "r"(bar_ptr_wt), "r"(STAGE_UNROLL * bytes_wt));
      if (!weight_warp)
        asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"
                     :
                     : "r"(bar_ptr_act), "r"(STAGE_UNROLL * bytes_act));

      if (PROFILE && blockIdx.y == 0 && ki == 0 && weight_warp)
        profile[blockIdx.x].weight_load_start = gclock64();
      if (PROFILE && blockIdx.y == 0 && ki == 0 && !weight_warp)
        profile[blockIdx.x].act_load_start = gclock64();

      for (int i = 0; i < STAGE_UNROLL; i++) {
        uint32_t smem_ptr_wt = __cvta_generic_to_shared(
            &sh_weights[(stage * STAGE_UNROLL + i) * TILE_M * TILE_K]);
        uint32_t crd0 = k + i * TILE_K;
        uint32_t crd1 = mib;
        if (weight_warp)
          asm volatile(
              "cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_"
              "tx::bytes [%0], [%1, {%3,%4}], "
              "[%2];"
              :
              : "r"(smem_ptr_wt), "l"(desc_ptr_wt), "r"(bar_ptr_wt), "r"(crd0),
                "r"(crd1)
              : "memory");

        uint32_t smem_ptr_act = __cvta_generic_to_shared(
            &sh_activations[(stage * STAGE_UNROLL + i) * TILE_N * TILE_K]);
        crd0 = k + i * TILE_K;
        crd1 = ni;
        if (!weight_warp)
          asm volatile(
              "cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_"
              "tx::bytes [%0], [%1, {%3,%4}], "
              "[%2];"
              :
              : "r"(smem_ptr_act), "l"(desc_ptr_act), "r"(bar_ptr_act),
                "r"(crd0), "r"(crd1)
              : "memory");
      }

      stage += 4;
      if (stage >= STAGES) {
        stage = warp_id % 4;
        phase ^= 1;
      }
    }
    // Wait for pending loads to be consumed before exiting, to avoid race
    for (int i = 0; i < (STAGES / 4) - 1; i++) {
      bar_wait(__cvta_generic_to_shared(&bar_data_consumed[stage]), phase ^ 1);
      stage += 4;
      if (stage >= STAGES) {
        stage = warp_id % 4;
        phase ^= 1;
      }
    }
  }
  // Compute threads
  else if (warp_id < 4) {
    // Sneak the bias load into the compute warps since they're just waiting for
    // stuff anyway. Rows at or beyond M are never stored (the output writes
    // below are guarded), so they get a zero instead of an out-of-bounds read
    // when M is not a multiple of TILE_M.
    if (threadIdx.x < TILE_M) {
      int bias_row = mib + threadIdx.x;
      sh_bias[threadIdx.x] =
          (bias_row < M) ? bias[bias_row] : __float2bfloat16(0.f);
    }

    int stage = warp_id;

    int phase = 0;

    int lane_id_div8 = lane_id / 8;
    int lane_id_mod8 = lane_id % 8;

    int lane_row_offset_wt = (lane_id_div8 % 2) ? 8 : 0;
    int lane_col_offset_wt = (lane_id_div8 / 2) ? 1 : 0;

    int row_wt = lane_id_mod8 + lane_row_offset_wt;
    int row_act = lane_id_mod8;

    // Shared-memory reads apply an XOR of the row index with the 16-byte
    // column index, matching the 128B swizzle the tensor maps are encoded
    // with in the host launcher.
    int row_offset_wt = (reinterpret_cast<uintptr_t>(sh_weights) / 128) % 8;
    int row_offset_act = row_offset_wt;

    uint32_t bar_ptr_wt = __cvta_generic_to_shared(&bar_wt_ready[stage]);
    uint32_t bar_ptr_act = __cvta_generic_to_shared(&bar_act_ready[stage]);

    bool weight_ready = bar_try_wait(bar_ptr_wt, phase);
    bool act_ready = bar_try_wait(bar_ptr_act, phase);

#pragma unroll 2
    for (int ki = 0; ki < K_LOOPS_COMPUTE; ki++) {
      int next_stage = stage + 4;
      int next_phase = phase;
      if (next_stage >= STAGES) {
        next_stage = warp_id;
        next_phase ^= 1;
      }
      uint32_t next_bar_ptr_wt =
          __cvta_generic_to_shared(&bar_wt_ready[next_stage]);
      uint32_t next_bar_ptr_act =
          __cvta_generic_to_shared(&bar_act_ready[next_stage]);

      // bar_ptr_wt / bar_ptr_act always refer to the stage consumed in this
      // iteration (they are advanced together with `stage` below).
      while (!weight_ready || !act_ready) {
        weight_ready = bar_try_wait(bar_ptr_wt, phase);
        act_ready = bar_try_wait(bar_ptr_act, phase);
      }

      if (PROFILE && blockIdx.y == 0 && threadIdx.x == 0 && ki == 0)
        profile[blockIdx.x].compute_start = gclock64();

      // Probe the next stage early so the wait above is usually skipped.
      if (ki + 1 < K_LOOPS_COMPUTE) {
        weight_ready = bar_try_wait(next_bar_ptr_wt, next_phase);
        act_ready = bar_try_wait(next_bar_ptr_act, next_phase);
      }

#pragma unroll
      for (int su = 0; su < STAGE_UNROLL; su++) {
        __nv_bfloat16* ptr_weights =
            &sh_weights[(stage * STAGE_UNROLL + su) * TILE_M * TILE_K];
        __nv_bfloat16* ptr_act =
            &sh_activations[(stage * STAGE_UNROLL + su) * TILE_N * TILE_K];

#pragma unroll
        for (int kii = 0; kii < TILE_K / 16; kii++) {
          __nv_bfloat16 a[8];
          __nv_bfloat16 b[4];

          int col = 2 * kii + lane_col_offset_wt;
          int col_sw = ((row_wt + row_offset_wt) % 8) ^ col;

          ldmatrix4(a, __cvta_generic_to_shared(
                           &ptr_weights[row_wt * TILE_K + col_sw * 8]));

          col = 2 * kii + lane_id_div8;
          col_sw = ((row_act + row_offset_act) % 8) ^ col;

          ldmatrix2(b, __cvta_generic_to_shared(
                           &ptr_act[row_act * TILE_K + 8 * col_sw]));

          HMMA_16816(accum, a, b, accum);
        }
      }

      // Release the stage back to the loading lanes.
      uint32_t bar_c = __cvta_generic_to_shared(&bar_data_consumed[stage]);
      asm volatile("mbarrier.arrive.shared::cta.b64 _, [%0];" : : "r"(bar_c));

      stage = next_stage;
      phase = next_phase;
      // Keep the polled barriers in step with the stage. Without this the
      // wait loop kept polling the first stage's barriers, which have already
      // completed, so a later stage could be read before its load arrived.
      bar_ptr_wt = next_bar_ptr_wt;
      bar_ptr_act = next_bar_ptr_act;
    }

    float4 accum4;
    accum4.x = accum[0];
    accum4.y = accum[1];
    accum4.z = accum[2];
    accum4.w = accum[3];
    reduction_buffer[threadIdx.x] = accum4;

    // NOTE(review): this barrier is reached only by the four compute warps;
    // the loading warps take the other branch or exit. Kept as in the
    // original -- confirm it is intended rather than a named 128-thread
    // barrier.
    __syncthreads();

    if (warp_id == 0) {
      int mi = mib + warp_id * WARP_TILE_M;
      int tm = mi + lane_id / 4;
      int tn = ni + 2 * (lane_id % 4);

      // Fold in the partial sums of warps 1-3 for the same lane position.
      float4 accum1 = reduction_buffer[32 + threadIdx.x];
      float4 accum2 = reduction_buffer[64 + threadIdx.x];
      float4 accum3 = reduction_buffer[96 + threadIdx.x];

      accum[0] = accum[0] + accum1.x + accum2.x + accum3.x;
      accum[1] = accum[1] + accum1.y + accum2.y + accum3.y;
      accum[2] = accum[2] + accum1.z + accum2.z + accum3.z;
      accum[3] = accum[3] + accum1.w + accum2.w + accum3.w;

      float bias_lo = __bfloat162float(sh_bias[tm - mib]);
      float bias_hi = __bfloat162float(sh_bias[tm + 8 - mib]);

      // Each lane owns rows tm and tm + 8 and columns tn and tn + 1.
      if (tn < N && tm < M)
        output[tn * M + tm] = __float2bfloat16(accum[0] + bias_lo);
      if (tn + 1 < N && tm < M)
        output[(tn + 1) * M + tm] = __float2bfloat16(accum[1] + bias_lo);
      if (tn < N && tm + 8 < M)
        output[tn * M + tm + 8] = __float2bfloat16(accum[2] + bias_hi);
      if (tn + 1 < N && tm + 8 < M)
        output[(tn + 1) * M + tm + 8] = __float2bfloat16(accum[3] + bias_hi);

      if (PROFILE && blockIdx.y == 0 && threadIdx.x == 0)
        profile[blockIdx.x].complete = gclock64();
    }
  }
#endif  // end if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}
csrc/moe/moe_ops.h
View file @
0da93439
...
...
@@ -70,4 +70,8 @@ torch::Tensor router_gemm_bf16_fp32(torch::Tensor const& input,
// Supports num_tokens in [1, 16], num_experts in {256, 384}, hidden_dim = 7168
void
dsv3_router_gemm
(
torch
::
Tensor
&
output
,
const
torch
::
Tensor
&
mat_a
,
const
torch
::
Tensor
&
mat_b
);
// gpt-oss optimized router GEMM kernel for SM90+
void
gpt_oss_router_gemm
(
torch
::
Tensor
&
output
,
torch
::
Tensor
input
,
torch
::
Tensor
weight
,
torch
::
Tensor
bias
);
#endif
csrc/moe/torch_bindings.cpp
View file @
0da93439
...
...
@@ -132,6 +132,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
// DeepSeek V3 optimized router GEMM for SM90+
m
.
def
(
"dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()"
);
// conditionally compiled so impl registration is in source file
// gpt-oss optimized router GEMM kernel for SM90+
m
.
def
(
"gpt_oss_router_gemm(Tensor! output, Tensor input, Tensor weights, "
"Tensor bias) -> ()"
);
m
.
impl
(
"gpt_oss_router_gemm"
,
torch
::
kCUDA
,
&
gpt_oss_router_gemm
);
#endif
}
...
...
csrc/ops.h
View file @
0da93439
...
...
@@ -201,7 +201,6 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel,
torch
::
Tensor
_zeros
,
int64_t
split_k_iters
,
int64_t
thx
,
int64_t
thy
);
torch
::
Tensor
permute_cols
(
torch
::
Tensor
const
&
A
,
torch
::
Tensor
const
&
perm
);
#endif
torch
::
Tensor
ggml_dequantize
(
torch
::
Tensor
W
,
int64_t
type
,
int64_t
m
,
...
...
@@ -262,7 +261,8 @@ void get_cutlass_moe_mm_data(
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
input_permutation
,
torch
::
Tensor
&
output_permutation
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
);
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
,
const
bool
is_gated
);
void
get_cutlass_moe_mm_problem_sizes_from_expert_offsets
(
const
torch
::
Tensor
&
expert_first_token_offset
,
...
...
csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
View file @
0da93439
...
...
@@ -300,6 +300,15 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
"Outer scale stride must be 1 when scales are not transposed"
);
}
int64_t
hidden_size
=
input
.
size
(
-
1
);
TORCH_CHECK
(
hidden_size
>
0
&&
hidden_size
%
group_size
==
0
,
"hidden_size must be a positive multiple of group_size"
);
int64_t
num_tokens
=
input
.
numel
()
/
hidden_size
;
int64_t
num_groups
=
hidden_size
/
group_size
;
TORCH_CHECK
(
scales
.
numel
()
>=
num_tokens
*
num_groups
,
"scales buffer too small: need "
,
num_tokens
*
num_groups
,
" elements, got "
,
scales
.
numel
());
rms_norm_per_block_quant_dispatch
(
out
,
input
,
weight
,
scales
,
group_size
,
var_epsilon
,
scale_ub
,
residual
,
is_scale_transposed
);
...
...
csrc/quantization/w8a8/cutlass/moe/moe_data.cu
View file @
0da93439
...
...
@@ -17,8 +17,11 @@ __global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids,
int32_t
*
problem_sizes2
,
int32_t
*
atomic_buffer
,
const
int
topk_length
,
const
int
n
,
const
int
k
)
{
const
int
k
,
const
bool
is_gated
)
{
int
expert_id
=
blockIdx
.
x
;
// For gated activations (gate + up), first GEMM output is 2*n.
// For non-gated activations (up only), first GEMM output is n.
int
const
n1
=
is_gated
?
2
*
n
:
n
;
int
occurrences
=
0
;
for
(
int
i
=
threadIdx
.
x
;
i
<
topk_length
;
i
+=
THREADS_PER_EXPERT
)
{
...
...
@@ -31,13 +34,13 @@ __global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids,
int
final_occurrences
=
atomic_buffer
[
expert_id
];
if
constexpr
(
!
SWAP_AB
)
{
problem_sizes1
[
expert_id
*
3
]
=
final_occurrences
;
problem_sizes1
[
expert_id
*
3
+
1
]
=
2
*
n
;
problem_sizes1
[
expert_id
*
3
+
1
]
=
n1
;
problem_sizes1
[
expert_id
*
3
+
2
]
=
k
;
problem_sizes2
[
expert_id
*
3
]
=
final_occurrences
;
problem_sizes2
[
expert_id
*
3
+
1
]
=
k
;
problem_sizes2
[
expert_id
*
3
+
2
]
=
n
;
}
else
{
problem_sizes1
[
expert_id
*
3
]
=
2
*
n
;
problem_sizes1
[
expert_id
*
3
]
=
n1
;
problem_sizes1
[
expert_id
*
3
+
1
]
=
final_occurrences
;
problem_sizes1
[
expert_id
*
3
+
2
]
=
k
;
problem_sizes2
[
expert_id
*
3
]
=
k
;
...
...
@@ -107,13 +110,11 @@ __global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
}
namespace
{
inline
void
launch_compute_problem_sizes
(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
atomic_buffer
,
int64_t
num_experts
,
int64_t
n
,
int64_t
k
,
cudaStream_t
stream
,
const
bool
swap_ab
)
{
inline
void
launch_compute_problem_sizes
(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
atomic_buffer
,
int64_t
num_experts
,
int64_t
n
,
int64_t
k
,
cudaStream_t
stream
,
const
bool
swap_ab
,
const
bool
is_gated
)
{
int
num_threads
=
min
(
THREADS_PER_EXPERT
,
topk_ids
.
numel
());
auto
const
*
topk_ptr
=
topk_ids
.
data_ptr
<
int32_t
>
();
...
...
@@ -125,7 +126,7 @@ inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids,
compute_problem_sizes
<
SwapAB
><<<
num_experts
,
num_threads
,
0
,
stream
>>>
(
topk_ptr
,
ps1_ptr
,
ps2_ptr
,
atomic_ptr
,
static_cast
<
int
>
(
topk_ids
.
numel
()),
static_cast
<
int
>
(
n
),
static_cast
<
int
>
(
k
));
static_cast
<
int
>
(
k
)
,
is_gated
);
});
}
}
// namespace
...
...
@@ -222,7 +223,8 @@ void get_cutlass_moe_mm_data_caller(
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
input_permutation
,
torch
::
Tensor
&
output_permutation
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
)
{
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
,
const
bool
is_gated
)
{
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
topk_ids
.
device
().
index
());
auto
options_int32
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt32
).
device
(
topk_ids
.
device
());
...
...
@@ -236,7 +238,7 @@ void get_cutlass_moe_mm_data_caller(
launch_compute_problem_sizes
(
topk_ids
,
problem_sizes1
,
problem_sizes2
,
atomic_buffer
,
num_experts
,
n
,
k
,
stream
,
may_swap_ab
);
may_swap_ab
,
is_gated
);
if
(
blockscale_offsets
.
has_value
())
{
// fp4 path
...
...
csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu
View file @
0da93439
...
...
@@ -75,7 +75,8 @@ void get_cutlass_moe_mm_data_caller(
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
input_permutation
,
torch
::
Tensor
&
output_permutation
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
);
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
,
const
bool
is_gated
);
void
get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller
(
const
torch
::
Tensor
&
expert_first_token_offset
,
...
...
@@ -278,7 +279,8 @@ void get_cutlass_moe_mm_data(
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
input_permutation
,
torch
::
Tensor
&
output_permutation
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
)
{
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
,
const
bool
is_gated
)
{
// This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for.
int32_t
version_num
=
get_sm_version_num
();
...
...
@@ -288,7 +290,7 @@ void get_cutlass_moe_mm_data(
get_cutlass_moe_mm_data_caller
(
topk_ids
,
expert_offsets
,
problem_sizes1
,
problem_sizes2
,
input_permutation
,
output_permutation
,
num_experts
,
n
,
k
,
blockscale_offsets
);
blockscale_offsets
,
is_gated
);
return
;
#endif
TORCH_CHECK_NOT_IMPLEMENTED
(
...
...
csrc/rocm/skinny_gemms.cu
View file @
0da93439
This diff is collapsed.
Click to expand it.
csrc/torch_bindings.cpp
View file @
0da93439
...
...
@@ -303,9 +303,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
") -> Tensor"
);
// conditionally compiled so impl registration is in source file
ops
.
def
(
"permute_cols(Tensor A, Tensor perm) -> Tensor"
);
ops
.
impl
(
"permute_cols"
,
torch
::
kCUDA
,
&
permute_cols
);
// Marlin Optimized Quantized GEMM (supports GPTQ, AWQ, FP8, NVFP4, MXFP4).
ops
.
def
(
"marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, "
...
...
@@ -489,8 +486,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor! problem_sizes1, Tensor! problem_sizes2, "
" Tensor! input_permutation, "
" Tensor! output_permutation, int num_experts, "
" int n, int k, Tensor? blockscale_offsets
) ->
"
"()"
);
" int n, int k, Tensor? blockscale_offsets
,
"
"
bool is_gated) ->
()"
);
ops
.
impl
(
"get_cutlass_moe_mm_data"
,
torch
::
kCUDA
,
&
get_cutlass_moe_mm_data
);
// compute per-expert problem sizes from expert_first_token_offset
...
...
docker/Dockerfile.rocm_base
View file @
0da93439
...
...
@@ -44,7 +44,7 @@ ENV DEBIAN_FRONTEND=noninteractive
# Install Python and other dependencies
RUN apt-get update -y \
&& apt-get install -y software-properties-common git curl sudo vim less libgfortran5 libopenmpi-dev libpci-dev \
&& apt-get install -y software-properties-common git curl sudo vim less libgfortran5 libopenmpi-dev libpci-dev liblzma-dev pkg-config \
&& for i in 1 2 3; do \
add-apt-repository -y ppa:deadsnakes/ppa && break || \
{ echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \
...
...
Prev
1
2
3
4
5
6
7
…
34
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment