Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e109e598
Unverified
Commit
e109e598
authored
Feb 22, 2025
by
Kaixi Hou
Committed by
GitHub
Feb 22, 2025
Browse files
[NVIDIA] Support nvfp4 cutlass gemm (#13571)
parent
8db1b9d0
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
494 additions
and
1 deletion
+494
-1
CMakeLists.txt
CMakeLists.txt
+3
-1
csrc/ops.h
csrc/ops.h
+5
-0
csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu
csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu
+37
-0
csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu
csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu
+280
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+7
-0
tests/kernels/test_nvfp4_scaled_mm.py
tests/kernels/test_nvfp4_scaled_mm.py
+150
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+12
-0
No files found.
CMakeLists.txt
View file @
e109e598
...
@@ -229,7 +229,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
...
@@ -229,7 +229,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case.
# Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case.
# Please keep this in sync with FetchContent_Declare line below.
# Please keep this in sync with FetchContent_Declare line below.
set
(
CUTLASS_REVISION
"v3.
7
.0"
CACHE STRING
"CUTLASS revision to use"
)
set
(
CUTLASS_REVISION
"v3.
8
.0"
CACHE STRING
"CUTLASS revision to use"
)
# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
if
(
DEFINED ENV{VLLM_CUTLASS_SRC_DIR}
)
if
(
DEFINED ENV{VLLM_CUTLASS_SRC_DIR}
)
...
@@ -267,6 +267,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
...
@@ -267,6 +267,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/permute_cols.cu"
"csrc/permute_cols.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
"csrc/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
"csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
"csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
"csrc/cutlass_extensions/common.cpp"
)
"csrc/cutlass_extensions/common.cpp"
)
...
@@ -383,6 +384,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
...
@@ -383,6 +384,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
if
(
${
CMAKE_CUDA_COMPILER_VERSION
}
VERSION_GREATER 12.8 AND FP4_ARCHS
)
if
(
${
CMAKE_CUDA_COMPILER_VERSION
}
VERSION_GREATER 12.8 AND FP4_ARCHS
)
set
(
SRCS
set
(
SRCS
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
)
)
set_gencode_flags_for_srcs
(
set_gencode_flags_for_srcs
(
SRCS
"
${
SRCS
}
"
SRCS
"
${
SRCS
}
"
...
...
csrc/ops.h
View file @
e109e598
...
@@ -152,6 +152,11 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
...
@@ -152,6 +152,11 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
int64_t
row
);
int64_t
row
);
#ifndef USE_ROCM
#ifndef USE_ROCM
void
cutlass_scaled_fp4_mm
(
torch
::
Tensor
&
D
,
torch
::
Tensor
const
&
A
,
torch
::
Tensor
const
&
B
,
torch
::
Tensor
const
&
A_sf
,
torch
::
Tensor
const
&
B_sf
,
torch
::
Tensor
const
&
alpha
);
bool
cutlass_scaled_mm_supports_fp8
(
int64_t
cuda_device_capability
);
bool
cutlass_scaled_mm_supports_fp8
(
int64_t
cuda_device_capability
);
bool
cutlass_scaled_mm_supports_block_fp8
(
int64_t
cuda_device_capability
);
bool
cutlass_scaled_mm_supports_block_fp8
(
int64_t
cuda_device_capability
);
...
...
csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu
0 → 100644
View file @
e109e598
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/all.h>
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
void
cutlass_scaled_fp4_mm_sm100a
(
torch
::
Tensor
&
D
,
torch
::
Tensor
const
&
A
,
torch
::
Tensor
const
&
B
,
torch
::
Tensor
const
&
A_sf
,
torch
::
Tensor
const
&
B_sf
,
torch
::
Tensor
const
&
alpha
);
#endif
void
cutlass_scaled_fp4_mm
(
torch
::
Tensor
&
D
,
torch
::
Tensor
const
&
A
,
torch
::
Tensor
const
&
B
,
torch
::
Tensor
const
&
A_sf
,
torch
::
Tensor
const
&
B_sf
,
torch
::
Tensor
const
&
alpha
)
{
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
return
cutlass_scaled_fp4_mm_sm100a
(
D
,
A
,
B
,
A_sf
,
B_sf
,
alpha
);
#endif
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No compiled nvfp4 mm kernel, vLLM should "
"be compiled using CUDA 12.8 and target "
"compute capability 100 or above."
);
}
csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu
0 → 100644
View file @
e109e598
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "cutlass_extensions/common.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/util/packed_stride.hpp"
using
namespace
cute
;
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
// Kernel Perf config
template
<
typename
T
>
struct
KernelTraits
;
template
<
>
struct
KernelTraits
<
float
>
{
using
MmaTileShape
=
Shape
<
_128
,
_128
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_1
,
_1
>
;
using
PerSmTileShape_MNK
=
Shape
<
_128
,
_128
,
_256
>
;
};
template
<
>
struct
KernelTraits
<
cutlass
::
half_t
>
{
using
MmaTileShape
=
Shape
<
_256
,
_256
,
_256
>
;
using
ClusterShape
=
Shape
<
_4
,
_4
,
_1
>
;
using
PerSmTileShape_MNK
=
Shape
<
_128
,
_256
,
_256
>
;
};
template
<
>
struct
KernelTraits
<
cutlass
::
bfloat16_t
>
{
using
MmaTileShape
=
Shape
<
_256
,
_256
,
_256
>
;
using
ClusterShape
=
Shape
<
_4
,
_4
,
_1
>
;
using
PerSmTileShape_MNK
=
Shape
<
_128
,
_256
,
_256
>
;
};
template
<
typename
T
>
struct
Fp4GemmSm100
{
// A matrix configuration
using
ElementA
=
cutlass
::
nv_float4_t
<
cutlass
::
float_e2m1_t
>
;
using
LayoutATag
=
cutlass
::
layout
::
RowMajor
;
static
constexpr
int
AlignmentA
=
32
;
// B matrix configuration
using
ElementB
=
cutlass
::
nv_float4_t
<
cutlass
::
float_e2m1_t
>
;
using
LayoutBTag
=
cutlass
::
layout
::
ColumnMajor
;
static
constexpr
int
AlignmentB
=
32
;
// C/D matrix configuration
using
ElementD
=
T
;
using
ElementC
=
T
;
using
LayoutCTag
=
cutlass
::
layout
::
RowMajor
;
using
LayoutDTag
=
cutlass
::
layout
::
RowMajor
;
static
constexpr
int
AlignmentD
=
128
/
cutlass
::
sizeof_bits
<
ElementD
>::
value
;
static
constexpr
int
AlignmentC
=
128
/
cutlass
::
sizeof_bits
<
ElementC
>::
value
;
// Kernel functional config
using
ElementAccumulator
=
float
;
using
ArchTag
=
cutlass
::
arch
::
Sm100
;
using
OperatorClass
=
cutlass
::
arch
::
OpClassBlockScaledTensorOp
;
// Kernel Perf config
using
MmaTileShape
=
typename
KernelTraits
<
T
>::
MmaTileShape
;
using
ClusterShape
=
typename
KernelTraits
<
T
>::
ClusterShape
;
using
PerSmTileShape_MNK
=
typename
KernelTraits
<
T
>::
PerSmTileShape_MNK
;
using
CollectiveEpilogue
=
typename
cutlass
::
epilogue
::
collective
::
CollectiveBuilder
<
ArchTag
,
OperatorClass
,
PerSmTileShape_MNK
,
ClusterShape
,
cutlass
::
epilogue
::
collective
::
EpilogueTileAuto
,
ElementAccumulator
,
ElementAccumulator
,
ElementC
,
LayoutCTag
,
AlignmentC
,
ElementD
,
LayoutDTag
,
AlignmentD
,
cutlass
::
epilogue
::
collective
::
EpilogueScheduleAuto
>::
CollectiveOp
;
using
CollectiveMainloop
=
typename
cutlass
::
gemm
::
collective
::
CollectiveBuilder
<
ArchTag
,
OperatorClass
,
ElementA
,
LayoutATag
,
AlignmentA
,
ElementB
,
LayoutBTag
,
AlignmentB
,
ElementAccumulator
,
MmaTileShape
,
ClusterShape
,
cutlass
::
gemm
::
collective
::
StageCountAutoCarveout
<
static_cast
<
int
>
(
sizeof
(
typename
CollectiveEpilogue
::
SharedStorage
))
>
,
cutlass
::
gemm
::
collective
::
KernelScheduleAuto
>::
CollectiveOp
;
using
GemmKernel
=
cutlass
::
gemm
::
kernel
::
GemmUniversal
<
Shape
<
int
,
int
,
int
,
int
>
,
CollectiveMainloop
,
CollectiveEpilogue
,
void
>
;
using
Gemm
=
cutlass
::
gemm
::
device
::
GemmUniversalAdapter
<
GemmKernel
>
;
using
StrideA
=
typename
Gemm
::
GemmKernel
::
StrideA
;
using
LayoutA
=
decltype
(
cute
::
make_layout
(
make_shape
(
0
,
0
,
0
),
StrideA
{}));
using
LayoutSFA
=
typename
Gemm
::
GemmKernel
::
CollectiveMainloop
::
LayoutSFA
;
using
StrideB
=
typename
Gemm
::
GemmKernel
::
StrideB
;
using
LayoutB
=
decltype
(
cute
::
make_layout
(
make_shape
(
0
,
0
,
0
),
StrideB
{}));
using
LayoutSFB
=
typename
Gemm
::
GemmKernel
::
CollectiveMainloop
::
LayoutSFB
;
using
StrideC
=
typename
Gemm
::
GemmKernel
::
StrideC
;
using
LayoutC
=
decltype
(
cute
::
make_layout
(
make_shape
(
0
,
0
,
0
),
StrideC
{}));
using
StrideD
=
typename
Gemm
::
GemmKernel
::
StrideD
;
using
LayoutD
=
decltype
(
cute
::
make_layout
(
make_shape
(
0
,
0
,
0
),
StrideD
{}));
};
template
<
typename
T
>
typename
T
::
Gemm
::
Arguments
args_from_options
(
at
::
Tensor
&
D
,
at
::
Tensor
const
&
A
,
at
::
Tensor
const
&
B
,
at
::
Tensor
const
&
A_sf
,
at
::
Tensor
const
&
B_sf
,
at
::
Tensor
const
&
alpha
,
int64_t
M
,
int64_t
N
,
int64_t
K
)
{
using
ElementA
=
typename
T
::
Gemm
::
ElementA
;
using
ElementB
=
typename
T
::
Gemm
::
ElementB
;
using
ElementSFA
=
cutlass
::
float_ue4m3_t
;
using
ElementSFB
=
cutlass
::
float_ue4m3_t
;
using
ElementD
=
typename
T
::
Gemm
::
ElementD
;
using
ElementCompute
=
float
;
using
StrideA
=
typename
T
::
StrideA
;
using
StrideB
=
typename
T
::
StrideB
;
using
StrideD
=
typename
T
::
StrideD
;
using
Sm100BlkScaledConfig
=
typename
T
::
Gemm
::
GemmKernel
::
CollectiveMainloop
::
Sm100BlkScaledConfig
;
int
m
=
static_cast
<
int
>
(
M
);
int
n
=
static_cast
<
int
>
(
N
);
int
k
=
static_cast
<
int
>
(
K
);
auto
stride_A
=
cutlass
::
make_cute_packed_stride
(
StrideA
{},
{
m
,
k
,
1
});
auto
stride_B
=
cutlass
::
make_cute_packed_stride
(
StrideB
{},
{
n
,
k
,
1
});
auto
stride_D
=
cutlass
::
make_cute_packed_stride
(
StrideD
{},
{
m
,
n
,
1
});
auto
layout_SFA
=
Sm100BlkScaledConfig
::
tile_atom_to_shape_SFA
(
cute
::
make_shape
(
m
,
n
,
k
,
1
));
auto
layout_SFB
=
Sm100BlkScaledConfig
::
tile_atom_to_shape_SFB
(
cute
::
make_shape
(
m
,
n
,
k
,
1
));
typename
T
::
Gemm
::
Arguments
arguments
{
cutlass
::
gemm
::
GemmUniversalMode
::
kGemm
,
{
m
,
n
,
k
,
1
},
{
// Mainloop arguments
static_cast
<
ElementA
const
*>
(
A
.
data_ptr
()),
stride_A
,
static_cast
<
ElementB
const
*>
(
B
.
data_ptr
()),
stride_B
,
static_cast
<
ElementSFA
const
*>
(
A_sf
.
data_ptr
()),
layout_SFA
,
static_cast
<
ElementSFB
const
*>
(
B_sf
.
data_ptr
()),
layout_SFB
},
{
// Epilogue arguments
{},
// epilogue.thread
static_cast
<
ElementD
const
*>
(
D
.
data_ptr
()),
stride_D
,
static_cast
<
ElementD
*>
(
D
.
data_ptr
()),
stride_D
}};
auto
&
fusion_args
=
arguments
.
epilogue
.
thread
;
fusion_args
.
alpha_ptr
=
static_cast
<
ElementCompute
const
*>
(
alpha
.
data_ptr
());
return
arguments
;
}
template
<
typename
T
>
void
runGemm
(
at
::
Tensor
&
D
,
at
::
Tensor
const
&
A
,
at
::
Tensor
const
&
B
,
at
::
Tensor
const
&
A_sf
,
at
::
Tensor
const
&
B_sf
,
at
::
Tensor
const
&
alpha
,
int64_t
m
,
int64_t
n
,
int64_t
k
,
cudaStream_t
stream
)
{
typename
Fp4GemmSm100
<
T
>::
Gemm
gemm
;
auto
arguments
=
args_from_options
<
Fp4GemmSm100
<
T
>>
(
D
,
A
,
B
,
A_sf
,
B_sf
,
alpha
,
m
,
n
,
k
);
size_t
workspace_size
=
Fp4GemmSm100
<
T
>::
Gemm
::
get_workspace_size
(
arguments
);
auto
const
workspace_options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kUInt8
).
device
(
A
.
device
());
auto
workspace
=
torch
::
empty
(
workspace_size
,
workspace_options
);
CUTLASS_CHECK
(
gemm
.
can_implement
(
arguments
));
CUTLASS_CHECK
(
gemm
.
initialize
(
arguments
,
workspace
.
data_ptr
(),
stream
));
CUTLASS_CHECK
(
gemm
.
run
(
arguments
,
workspace
.
data_ptr
(),
stream
));
}
#else
template
<
typename
T
>
void
runGemm
(
at
::
Tensor
&
D
,
at
::
Tensor
const
&
A
,
at
::
Tensor
const
&
B
,
at
::
Tensor
const
&
A_sf
,
at
::
Tensor
const
&
B_sf
,
at
::
Tensor
const
&
alpha
,
int64_t
m
,
int64_t
n
,
int64_t
k
,
cudaStream_t
stream
)
{
TORCH_CHECK
(
false
,
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
"a CUTLASS 3.8 source directory to enable support."
);
}
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
#define CHECK_TYPE(x, st, m) \
TORCH_CHECK(x.scalar_type() == st, "Inconsistency of Tensor type:", m)
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) \
TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
#define CHECK_INPUT(x, st, m) \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m); \
CHECK_TYPE(x, st, m)
constexpr
auto
FLOAT4_E2M1X2
=
at
::
ScalarType
::
Byte
;
constexpr
auto
SF_DTYPE
=
at
::
ScalarType
::
Float8_e4m3fn
;
void
cutlass_scaled_fp4_mm_sm100a
(
torch
::
Tensor
&
D
,
torch
::
Tensor
const
&
A
,
torch
::
Tensor
const
&
B
,
torch
::
Tensor
const
&
A_sf
,
torch
::
Tensor
const
&
B_sf
,
torch
::
Tensor
const
&
alpha
)
{
CHECK_INPUT
(
A
,
FLOAT4_E2M1X2
,
"a"
);
CHECK_INPUT
(
B
,
FLOAT4_E2M1X2
,
"b"
);
CHECK_INPUT
(
A_sf
,
SF_DTYPE
,
"scale_a"
);
CHECK_INPUT
(
B_sf
,
SF_DTYPE
,
"scale_b"
);
CHECK_INPUT
(
alpha
,
at
::
ScalarType
::
Float
,
"alpha"
);
TORCH_CHECK
(
A
.
dim
()
==
2
,
"a must be a matrix"
);
TORCH_CHECK
(
B
.
dim
()
==
2
,
"b must be a matrix"
);
TORCH_CHECK
(
A
.
sizes
()[
1
]
==
B
.
sizes
()[
1
],
"a and b shapes cannot be multiplied ("
,
A
.
sizes
()[
0
],
"x"
,
A
.
sizes
()[
1
],
" and "
,
B
.
sizes
()[
0
],
"x"
,
B
.
sizes
()[
1
],
")"
);
auto
const
m
=
A
.
sizes
()[
0
];
auto
const
n
=
B
.
sizes
()[
0
];
auto
const
k
=
A
.
sizes
()[
1
]
*
2
;
constexpr
int
alignment
=
32
;
TORCH_CHECK
(
k
%
alignment
==
0
,
"Expected k to be divisible by "
,
alignment
,
", but got a shape: ("
,
A
.
sizes
()[
0
],
"x"
,
A
.
sizes
()[
1
],
"), k: "
,
k
,
"."
);
TORCH_CHECK
(
n
%
alignment
==
0
,
"Expected n to be divisible by "
,
alignment
,
", but got b shape: ("
,
B
.
sizes
()[
0
],
"x"
,
B
.
sizes
()[
1
],
")."
);
auto
round_up
=
[](
int
x
,
int
y
)
{
return
(
x
+
y
-
1
)
/
y
*
y
;
};
int
rounded_m
=
round_up
(
m
,
128
);
int
rounded_n
=
round_up
(
n
,
128
);
// Since k is divisible by 32 (alignment), k / 16 is guaranteed to be an
// integer.
int
rounded_k
=
round_up
(
k
/
16
,
4
);
TORCH_CHECK
(
A_sf
.
dim
()
==
2
,
"scale_a must be a matrix"
);
TORCH_CHECK
(
B_sf
.
dim
()
==
2
,
"scale_b must be a matrix"
);
TORCH_CHECK
(
A_sf
.
sizes
()[
1
]
==
B_sf
.
sizes
()[
1
],
"scale_a and scale_b shapes cannot be multiplied ("
,
A_sf
.
sizes
()[
0
],
"x"
,
A_sf
.
sizes
()[
1
],
" and "
,
B_sf
.
sizes
()[
0
],
"x"
,
B_sf
.
sizes
()[
1
],
")"
);
TORCH_CHECK
(
A_sf
.
sizes
()[
0
]
==
rounded_m
&&
A_sf
.
sizes
()[
1
]
==
rounded_k
,
"scale_a must be padded and swizzled to a shape ("
,
rounded_m
,
"x"
,
rounded_k
,
"), but got a shape ("
,
A_sf
.
sizes
()[
0
],
"x"
,
A_sf
.
sizes
()[
1
],
")"
);
TORCH_CHECK
(
B_sf
.
sizes
()[
0
]
==
rounded_n
&&
B_sf
.
sizes
()[
1
]
==
rounded_k
,
"scale_b must be padded and swizzled to a shape ("
,
rounded_n
,
"x"
,
rounded_k
,
"), but got a shape ("
,
B_sf
.
sizes
()[
0
],
"x"
,
B_sf
.
sizes
()[
1
],
")"
);
auto
out_dtype
=
D
.
dtype
();
at
::
cuda
::
CUDAGuard
device_guard
{(
char
)
A
.
get_device
()};
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
A
.
get_device
());
if
(
out_dtype
==
at
::
ScalarType
::
Half
)
{
runGemm
<
cutlass
::
half_t
>
(
D
,
A
,
B
,
A_sf
,
B_sf
,
alpha
,
m
,
n
,
k
,
stream
);
}
else
if
(
out_dtype
==
at
::
ScalarType
::
BFloat16
)
{
runGemm
<
cutlass
::
bfloat16_t
>
(
D
,
A
,
B
,
A_sf
,
B_sf
,
alpha
,
m
,
n
,
k
,
stream
);
}
else
if
(
out_dtype
==
at
::
ScalarType
::
Float
)
{
runGemm
<
float
>
(
D
,
A
,
B
,
A_sf
,
B_sf
,
alpha
,
m
,
n
,
k
,
stream
);
}
else
{
TORCH_CHECK
(
false
,
"Unsupported output data type of nvfp4 mm"
);
}
}
csrc/torch_bindings.cpp
View file @
e109e598
...
@@ -302,6 +302,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -302,6 +302,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"SymInt size_k) -> Tensor"
);
"SymInt size_k) -> Tensor"
);
// conditionally compiled so impl registration is in source file
// conditionally compiled so impl registration is in source file
// CUTLASS nvfp4 block scaled GEMM
ops
.
def
(
"cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
" Tensor block_scale_a, Tensor block_scale_b,"
" Tensor alpha) -> ()"
);
ops
.
impl
(
"cutlass_scaled_fp4_mm"
,
torch
::
kCUDA
,
&
cutlass_scaled_fp4_mm
);
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
// quantization, as well as bias
// quantization, as well as bias
ops
.
def
(
ops
.
def
(
...
...
tests/kernels/test_nvfp4_scaled_mm.py
0 → 100644
View file @
e109e598
# SPDX-License-Identifier: Apache-2.0
import
pytest
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
if
not
current_platform
.
has_device_capability
(
100
):
pytest
.
skip
(
reason
=
"Nvfp4 Requires compute capability of 10 or above."
,
allow_module_level
=
True
)
DTYPES
=
[
torch
.
float16
,
torch
.
bfloat16
]
# m, n, k
SHAPES
=
[(
128
,
128
,
64
),
(
128
,
128
,
128
),
(
256
,
128
,
64
),
(
128
,
256
,
128
)]
PAD_SHAPES
=
[(
150
,
128
,
64
),
(
128
,
128
,
96
)]
SHAPES
.
extend
(
PAD_SHAPES
)
SEEDS
=
[
42
]
CUDA_DEVICES
=
[
'cuda:0'
]
FLOAT4_E2M1_MAX
=
scalar_types
.
float4_e2m1fn
.
max
()
FLOAT8_E4M3_MAX
=
torch
.
finfo
(
torch
.
float8_e4m3fn
).
max
kE2M1ToFloatArray
=
[
0.
,
0.5
,
1.
,
1.5
,
2.
,
3.
,
4.
,
6.
,
]
def
e2m1_to_fp32
(
int4_value
):
signBit
=
(
int4_value
&
0x8
)
int4_absValue
=
int4_value
&
0x7
float_result
=
kE2M1ToFloatArray
[
int4_absValue
]
if
(
signBit
):
float_result
=
-
float_result
return
float_result
def
break_fp4_bytes
(
a
,
dtype
):
assert
(
a
.
dtype
==
torch
.
uint8
)
m
,
n
=
a
.
shape
a
=
a
.
flatten
()
# Get upper 4 bits
highHalfByte
=
(
a
&
0xF0
)
>>
4
# Get lower 4 bits
lowHalfByte
=
a
&
0x0F
fH
=
torch
.
tensor
([
e2m1_to_fp32
(
x
)
for
x
in
highHalfByte
]).
to
(
a
.
device
)
fL
=
torch
.
tensor
([
e2m1_to_fp32
(
x
)
for
x
in
lowHalfByte
]).
to
(
a
.
device
)
# [0xAB, 0xCD] -> [0xB, 0xA, 0xD, 0xC]
out
=
torch
.
stack
((
fL
,
fH
),
dim
=-
1
).
reshape
(
m
,
n
*
2
)
return
out
def
convert_swizzled_to_linear
(
a_sf_swizzled
:
torch
.
Tensor
,
m
,
k
,
block_size
):
sf_m
,
sf_k
=
a_sf_swizzled
.
shape
m_tiles
=
(
m
+
128
-
1
)
//
128
f
=
block_size
*
4
k_tiles
=
(
k
+
f
-
1
)
//
f
tmp
=
torch
.
reshape
(
a_sf_swizzled
,
(
1
,
m_tiles
,
k_tiles
,
32
,
4
,
4
))
tmp
=
torch
.
permute
(
tmp
,
(
0
,
1
,
4
,
3
,
2
,
5
))
out
=
tmp
.
reshape
(
m_tiles
*
128
,
k_tiles
*
f
//
block_size
)
return
out
[
0
:
m
,
0
:
k
]
def
dequantize_to_dtype
(
tensor_fp4
,
tensor_sf
,
global_scale
,
dtype
,
device
,
block_size
=
16
):
"""Dequantize the fp4 tensor back to high precision."""
# Two fp4 values are packed into one uint8.
assert
tensor_fp4
.
dtype
==
torch
.
uint8
m
,
packed_k
=
tensor_fp4
.
shape
k
=
packed_k
*
2
tensor_f32
=
break_fp4_bytes
(
tensor_fp4
,
dtype
)
tensor_f32
=
tensor_f32
.
reshape
(
m
,
k
//
block_size
,
block_size
)
tensor_sf
=
tensor_sf
.
view
(
torch
.
float8_e4m3fn
)
tensor_sf
=
convert_swizzled_to_linear
(
tensor_sf
,
m
,
k
,
block_size
)
tensor_sf_dtype
=
tensor_sf
.
to
(
torch
.
float32
)
/
global_scale
# scale the tensor
out
=
(
tensor_f32
*
tensor_sf_dtype
.
unsqueeze
(
-
1
)).
reshape
(
m
,
k
)
return
out
def
get_ref_results
(
a_fp4
,
b_fp4
,
a_sf
,
b_sf
,
a_global_scale
,
b_global_scale
,
m
,
n
,
dtype
,
block_size
,
device
):
_
,
m_k
=
a_fp4
.
shape
_
,
n_k
=
b_fp4
.
shape
assert
(
m_k
==
n_k
)
a_in_dtype
=
dequantize_to_dtype
(
a_fp4
,
a_sf
,
a_global_scale
,
dtype
=
dtype
,
device
=
device
,
block_size
=
block_size
)
b_in_dtype
=
dequantize_to_dtype
(
b_fp4
,
b_sf
,
b_global_scale
,
dtype
=
dtype
,
device
=
device
,
block_size
=
block_size
)
return
torch
.
matmul
(
a_in_dtype
,
b_in_dtype
.
t
())
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"shape"
,
SHAPES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
def
test_nvfp4_gemm
(
dtype
:
torch
.
dtype
,
shape
:
tuple
[
int
,
int
,
int
],
seed
:
int
,
device
:
str
,
)
->
None
:
current_platform
.
seed_everything
(
seed
)
m
,
n
,
packed_k
=
shape
k
=
packed_k
*
2
block_size
=
16
a_dtype
=
torch
.
randn
((
m
,
k
),
dtype
=
dtype
,
device
=
device
)
b_dtype
=
torch
.
randn
((
n
,
k
),
dtype
=
dtype
,
device
=
device
)
a_global_scale
=
((
FLOAT8_E4M3_MAX
*
FLOAT4_E2M1_MAX
)
/
torch
.
amax
(
a_dtype
.
flatten
(),
dim
=-
1
)).
to
(
torch
.
float32
)
b_global_scale
=
((
FLOAT8_E4M3_MAX
*
FLOAT4_E2M1_MAX
)
/
torch
.
amax
(
b_dtype
.
flatten
(),
dim
=-
1
)).
to
(
torch
.
float32
)
alpha
=
1.
/
(
a_global_scale
*
b_global_scale
)
a_fp4
,
a_scale_interleaved
=
ops
.
scaled_fp4_quant
(
a_dtype
,
a_global_scale
)
b_fp4
,
b_scale_interleaved
=
ops
.
scaled_fp4_quant
(
b_dtype
,
b_global_scale
)
expected_out
=
get_ref_results
(
a_fp4
,
b_fp4
,
a_scale_interleaved
,
b_scale_interleaved
,
a_global_scale
,
b_global_scale
,
m
,
n
,
dtype
,
block_size
,
device
)
out
=
ops
.
cutlass_scaled_fp4_mm
(
a_fp4
,
b_fp4
,
a_scale_interleaved
,
b_scale_interleaved
,
alpha
,
dtype
)
torch
.
testing
.
assert_close
(
out
,
expected_out
.
to
(
dtype
=
dtype
),
atol
=
1e-1
,
rtol
=
1e-1
)
vllm/_custom_ops.py
View file @
e109e598
...
@@ -433,6 +433,18 @@ if hasattr(torch.ops._C, "ggml_dequantize"):
...
@@ -433,6 +433,18 @@ if hasattr(torch.ops._C, "ggml_dequantize"):
# cutlass
# cutlass
def
cutlass_scaled_fp4_mm
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
block_scale_a
:
torch
.
Tensor
,
block_scale_b
:
torch
.
Tensor
,
alpha
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
)
->
torch
.
Tensor
:
assert
a
.
ndim
==
2
and
b
.
ndim
==
2
m
,
n
=
a
.
shape
[
0
],
b
.
shape
[
0
]
out
=
torch
.
empty
((
m
,
n
),
dtype
=
out_dtype
,
device
=
a
.
device
)
torch
.
ops
.
_C
.
cutlass_scaled_fp4_mm
(
out
,
a
,
b
,
block_scale_a
,
block_scale_b
,
alpha
)
return
out
def
cutlass_scaled_mm_supports_fp8
(
cuda_device_capability
:
int
)
->
bool
:
def
cutlass_scaled_mm_supports_fp8
(
cuda_device_capability
:
int
)
->
bool
:
return
torch
.
ops
.
_C
.
cutlass_scaled_mm_supports_fp8
(
cuda_device_capability
)
return
torch
.
ops
.
_C
.
cutlass_scaled_mm_supports_fp8
(
cuda_device_capability
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment