sglang · Commits · e9f8e423

Commit e9f8e423, authored Mar 24, 2025 by Trevor Morris, committed via GitHub on Mar 24, 2025

Support FP4 gemm (1/2) (#3899)
Parent: 22c3702e

Showing 11 changed files with 1245 additions and 5 deletions (+1245, -5)
sgl-kernel/csrc/gemm/nvfp4_quant_entry.cu        +29   -0
sgl-kernel/csrc/gemm/nvfp4_quant_kernels.cu      +394  -0
sgl-kernel/csrc/gemm/nvfp4_scaled_mm_entry.cu    +39   -0
sgl-kernel/csrc/gemm/nvfp4_scaled_mm_kernels.cu  +365  -0
sgl-kernel/csrc/torch_extension.cc               +11   -0
sgl-kernel/include/sgl_kernel_ops.h              +9    -0
sgl-kernel/python/sgl_kernel/__init__.py         +2    -0
sgl-kernel/python/sgl_kernel/gemm.py             +71   -1
sgl-kernel/setup.py                              +10   -4
sgl-kernel/tests/test_fp4_gemm.py                +151  -0
sgl-kernel/tests/test_fp4_quantize.py            +164  -0
sgl-kernel/csrc/gemm/nvfp4_quant_entry.cu (new file, mode 100644)
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <torch/all.h>
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
void scaled_fp4_quant_sm100a(
    torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf);
#endif

void scaled_fp4_quant(
    torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf) {
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
  return scaled_fp4_quant_sm100a(output, input, output_sf, input_sf);
#endif
  TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization");
}
sgl-kernel/csrc/gemm/nvfp4_quant_kernels.cu (new file, mode 100644)
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <torch/all.h>
#include "utils.h"
// Get type2 from type or vice versa (applied to half and bfloat16)
template <typename T>
struct TypeConverter {
  using Type = half2;
};  // keep for generality

template <>
struct TypeConverter<half2> {
  using Type = half;
};

template <>
struct TypeConverter<half> {
  using Type = half2;
};

template <>
struct TypeConverter<__nv_bfloat162> {
  using Type = __nv_bfloat16;
};

template <>
struct TypeConverter<__nv_bfloat16> {
  using Type = __nv_bfloat162;
};

#define ELTS_PER_THREAD 8

constexpr int CVT_FP4_ELTS_PER_THREAD = 8;
constexpr int CVT_FP4_SF_VEC_SIZE = 16;

// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
  // PTX instructions used here requires sm100a.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) && __CUDA_ARCH_HAS_FEATURE__(SM100_ALL)
  uint32_t val;
  asm volatile(
      "{\n"
      ".reg .b8 byte0;\n"
      ".reg .b8 byte1;\n"
      ".reg .b8 byte2;\n"
      ".reg .b8 byte3;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
      "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
      "}"
      : "=r"(val)
      : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]),
        "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7]));
  return val;
#else
  return 0;
#endif
}

// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
  // PTX instructions used here requires sm100a.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) && __CUDA_ARCH_HAS_FEATURE__(SM100_ALL)
  uint32_t val;
  asm volatile(
      "{\n"
      ".reg .b8 byte0;\n"
      ".reg .b8 byte1;\n"
      ".reg .b8 byte2;\n"
      ".reg .b8 byte3;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
      "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
      "}"
      : "=r"(val)
      : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y),
        "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y));
  return val;
#else
  return 0;
#endif
}

// Fast reciprocal.
inline __device__ float reciprocal_approximate_ftz(float a) {
  float b;
  asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
  return b;
}

template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx, int numCols, SFType* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || CVT_FP4_NUM_THREADS_PER_SF == 2);

  // One pair of threads write one SF to global memory.
  // TODO: stage through smem for packed STG.32
  // is it better than STG.8 from 4 threads ?
  if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) {
    // SF vector index (16 elements share one SF in the K dimension).
    int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
    int32_t mIdx = rowIdx;

    // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
    // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
    int32_t mTileIdx = mIdx / (32 * 4);
    // SF vector size 16.
    int factor = CVT_FP4_SF_VEC_SIZE * 4;
    int32_t numKTiles = (numCols + factor - 1) / factor;
    int64_t mTileStride = numKTiles * 32 * 4 * 4;

    int32_t kTileIdx = (kIdx / 4);
    int64_t kTileStride = 32 * 4 * 4;

    // M tile layout [32, 4] is column-major.
    int32_t outerMIdx = (mIdx % 32);
    int64_t outerMStride = 4 * 4;

    int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
    int64_t innerMStride = 4;

    int32_t innerKIdx = (kIdx % 4);
    int64_t innerKStride = 1;

    // Compute the global offset.
    int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + outerMIdx * outerMStride +
                       innerMIdx * innerMStride + innerKIdx * innerKStride;

    return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
  }
#endif
  return nullptr;
}

// Define a 16 bytes packed data type.
template <class Type>
struct PackedVec {
  typename TypeConverter<Type>::Type elts[4];
};

template <>
struct PackedVec<__nv_fp8_e4m3> {
  __nv_fp8x2_e4m3 elts[8];
};

// Quantizes the provided PackedVec into the uint32_t output
template <class Type, bool UE8M0_SF = false>
__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal, uint8_t* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  // Get absolute maximum values among the local 8 values.
  auto localMax = __habs2(vec.elts[0]);

  // Local maximum value.
#pragma unroll
  for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
    localMax = __hmax2(localMax, __habs2(vec.elts[i]));
  }

  // Get the absolute maximum among all 16 values (two threads).
  localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
  // Get the final absolute maximum values.
  float vecMax = float(__hmax(localMax.x, localMax.y));

  // Get the SF (max value of the vector / max value of e2m1).
  // maximum value of e2m1 = 6.0.
  // TODO: use half as compute data type.
  float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
  // 8 bits representation of the SF.
  uint8_t fp8SFVal;
  // Write the SF to global memory (STG.8).
  if constexpr (UE8M0_SF) {
    __nv_fp8_e8m0 tmp;
    tmp.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf);
    SFValue = static_cast<float>(tmp);
    fp8SFVal = tmp.__x;
  } else {
    // Here SFValue is always positive, so E4M3 is the same as UE4M3.
    __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
    fp8SFVal = tmp.__x;
    SFValue = static_cast<float>(tmp);
  }
  // Get the output scale.
  // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) *
  //                       reciprocal(SFScaleVal))
  float outputScale =
      SFValue != 0 ? reciprocal_approximate_ftz(SFValue * reciprocal_approximate_ftz(SFScaleVal)) : 0.0f;

  if (SFout) {
    // Write the SF to global memory (STG.8).
    *SFout = fp8SFVal;
  }

  // Convert the input to float.
  float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];

#pragma unroll
  for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
    if constexpr (std::is_same_v<Type, half>) {
      fp2Vals[i] = __half22float2(vec.elts[i]);
    } else {
      fp2Vals[i] = __bfloat1622float2(vec.elts[i]);
    }
    fp2Vals[i].x *= outputScale;
    fp2Vals[i].y *= outputScale;
  }

  // Convert to e2m1 values.
  uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);

  // Write the e2m1 values to global memory.
  return e2m1Vec;
#else
  return 0;
#endif
}

// Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false>
__global__ void
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__(512, 4) cvt_fp16_to_fp4(
#else
cvt_fp16_to_fp4(
#endif
    int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, uint32_t* out, uint32_t* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  using PackedVec = PackedVec<Type>;
  static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched.");

  // Get the global scaling factor, which will be applied to the SF.
  // Note SFScale is the same as next GEMM's alpha, which is
  // (448.f / (Alpha_A / 6.f)).
  float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0];

  // Input tensor row/col loops.
  for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
    for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD; colIdx += blockDim.x) {
      int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;
      PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
      // Get the output tensor offset.
      // Same as inOffset because 8 elements are packed into one uint32_t.
      int64_t outOffset = inOffset;
      auto& out_pos = out[outOffset];

      auto sf_out =
          cvt_quant_to_fp4_get_sf_out_offset<uint32_t, CVT_FP4_NUM_THREADS_PER_SF>(rowIdx, colIdx, numCols, SFout);

      out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
    }
  }
#endif
}

template <typename T>
void invokeFP4Quantization(
    int m, int n, T const* input, float const* SFScale, int64_t* output, int32_t* SFOuput, bool useUE8M0,
    int multiProcessorCount, cudaStream_t stream) {
  // Grid, Block size.
  // Each thread converts 8 values.
  dim3 block(std::min(int(n / ELTS_PER_THREAD), 512));
  // Get number of blocks per SM (assume we can fully utilize the SM).
  int const numBlocksPerSM = 2048 / block.x;
  dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));

  // Launch the cvt kernel.
  if (useUE8M0) {
    cvt_fp16_to_fp4<T, true><<<grid, block, 0, stream>>>(
        m, n, input, SFScale, reinterpret_cast<uint32_t*>(output), reinterpret_cast<uint32_t*>(SFOuput));
  } else {
    cvt_fp16_to_fp4<T, false><<<grid, block, 0, stream>>>(
        m, n, input, SFScale, reinterpret_cast<uint32_t*>(output), reinterpret_cast<uint32_t*>(SFOuput));
  }
}

// Instantiate the function.
template void invokeFP4Quantization(
    int m, int n, half const* input, float const* SFScale, int64_t* output, int32_t* SFOuput, bool useUE8M0,
    int multiProcessorCount, cudaStream_t stream);

template void invokeFP4Quantization(
    int m, int n, __nv_bfloat16 const* input, float const* SFScale, int64_t* output, int32_t* SFOuput, bool useUE8M0,
    int multiProcessorCount, cudaStream_t stream);

inline int getMultiProcessorCount() {
  static int multi_processor_count = []() {
    int device_id = 0;
    int count = 0;

    // Get the current CUDA device ID
    CHECK_CUDA_SUCCESS(cudaGetDevice(&device_id));

    // Get the number of multiprocessors for the current device
    CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, device_id));

    return count;  // Initialize the static variable
  }();

  return multi_processor_count;  // Return the cached value on subsequent calls
}

void scaled_fp4_quant_sm100a(
    torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf) {
  int32_t m = input.size(0);
  int32_t n = input.size(1);

  TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");

  int multiProcessorCount = getMultiProcessorCount();

  auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
  auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
  auto output_ptr = static_cast<int64_t*>(output.data_ptr());
  at::cuda::CUDAGuard device_guard{(char)input.get_device()};
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(input.get_device());

  // We don't support e8m0 scales at this moment.
  bool useUE8M0 = false;

  switch (input.scalar_type()) {
    case torch::kHalf: {
      auto input_ptr = reinterpret_cast<half const*>(input.data_ptr());
      invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr, sf_out, useUE8M0, multiProcessorCount, stream);
      break;
    }
    case torch::kBFloat16: {
      auto input_ptr = reinterpret_cast<__nv_bfloat16 const*>(input.data_ptr());
      invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr, sf_out, useUE8M0, multiProcessorCount, stream);
      break;
    }
    default: {
      std::cerr << "Observing: " << input.scalar_type() << " for the input datatype which is invalid";
      throw std::runtime_error("Unsupported input data type for quantize_to_fp4.");
    }
  }
}
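For orientation, the per-16-element recipe that cvt_warp_fp16_to_fp4 applies on the device can be written out in a few lines of PyTorch. This is only a host-side sketch that mirrors the ref_nvfp4_quant reference in the new tests further below; the kernel itself works per warp with __shfl_xor_sync, uses the approximate reciprocal, and stores the scale as UE4M3. The function and variable names here are illustrative, not part of the commit.

import torch

def nvfp4_block_scale_sketch(x: torch.Tensor, global_scale: torch.Tensor):
    # x: (m, n) fp16/bf16 values with n divisible by 16; global_scale ~= 448 * 6 / amax(|x|).
    m, n = x.shape
    blocks = x.float().reshape(m, n // 16, 16)
    vec_max = blocks.abs().amax(dim=-1, keepdim=True)            # per-16-element absolute maximum
    sf = (global_scale * vec_max / 6.0).to(torch.float8_e4m3fn)  # 6.0 is the e2m1 maximum; scale stored as e4m3
    sf_f32 = sf.float()
    # Output scale applied to the block before e2m1 rounding; zero blocks stay zero.
    out_scale = torch.where(sf_f32 == 0, torch.zeros_like(sf_f32), 1.0 / (sf_f32 / global_scale))
    scaled = (blocks * out_scale).clamp(-6.0, 6.0)               # these values are then rounded to e2m1 and packed
    return scaled.reshape(m, n), sf.view(m, n // 16)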
sgl-kernel/csrc/gemm/nvfp4_scaled_mm_entry.cu (new file, mode 100644)
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <torch/all.h>
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
void cutlass_scaled_fp4_mm_sm100a(
    torch::Tensor& D,
    torch::Tensor const& A,
    torch::Tensor const& B,
    torch::Tensor const& A_sf,
    torch::Tensor const& B_sf,
    torch::Tensor const& alpha);
#endif

void cutlass_scaled_fp4_mm(
    torch::Tensor& D,
    torch::Tensor const& A,
    torch::Tensor const& B,
    torch::Tensor const& A_sf,
    torch::Tensor const& B_sf,
    torch::Tensor const& alpha) {
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
  return cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha);
#endif
  TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 mm kernel.");
}
sgl-kernel/csrc/gemm/nvfp4_scaled_mm_kernels.cu (new file, mode 100644)
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
// clang-format off
#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/util/packed_stride.hpp"
// clang-format on
/**
* Helper function for checking CUTLASS errors
*/
#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
TORCH_CHECK(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \
}
using namespace cute;

#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
// Kernel Perf config
template <typename T>
struct KernelTraits;

template <>
struct KernelTraits<float> {
  using MmaTileShape = Shape<_128, _128, _256>;
  using ClusterShape = Shape<_1, _1, _1>;
  using PerSmTileShape_MNK = Shape<_128, _128, _256>;
};

template <>
struct KernelTraits<cutlass::half_t> {
  using MmaTileShape = Shape<_256, _256, _256>;
  using ClusterShape = Shape<_4, _4, _1>;
  using PerSmTileShape_MNK = Shape<_128, _256, _256>;
};

template <>
struct KernelTraits<cutlass::bfloat16_t> {
  using MmaTileShape = Shape<_256, _256, _256>;
  using ClusterShape = Shape<_4, _4, _1>;
  using PerSmTileShape_MNK = Shape<_128, _256, _256>;
};

template <typename T>
struct Fp4GemmSm100 {
  // A matrix configuration
  using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
  using LayoutATag = cutlass::layout::RowMajor;
  static constexpr int AlignmentA = 32;

  // B matrix configuration
  using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
  using LayoutBTag = cutlass::layout::ColumnMajor;
  static constexpr int AlignmentB = 32;

  // C/D matrix configuration
  using ElementD = T;
  using ElementC = T;
  using LayoutCTag = cutlass::layout::RowMajor;
  using LayoutDTag = cutlass::layout::RowMajor;
  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;

  // Kernel functional config
  using ElementAccumulator = float;
  using ArchTag = cutlass::arch::Sm100;
  using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;

  // Kernel Perf config
  using MmaTileShape = typename KernelTraits<T>::MmaTileShape;
  using ClusterShape = typename KernelTraits<T>::ClusterShape;
  using PerSmTileShape_MNK = typename KernelTraits<T>::PerSmTileShape_MNK;

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      PerSmTileShape_MNK,
      ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator,
      ElementAccumulator,
      ElementC,
      LayoutCTag,
      AlignmentC,
      ElementD,
      LayoutDTag,
      AlignmentD,
      cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp;

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      ElementA,
      LayoutATag,
      AlignmentA,
      ElementB,
      LayoutBTag,
      AlignmentB,
      ElementAccumulator,
      MmaTileShape,
      ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
          sizeof(typename CollectiveEpilogue::SharedStorage))>,
      cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp;

  using GemmKernel =
      cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  using StrideA = typename Gemm::GemmKernel::StrideA;
  using LayoutA = decltype(cute::make_layout(make_shape(0, 0, 0), StrideA{}));
  using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using LayoutB = decltype(cute::make_layout(make_shape(0, 0, 0), StrideB{}));
  using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using LayoutC = decltype(cute::make_layout(make_shape(0, 0, 0), StrideC{}));
  using StrideD = typename Gemm::GemmKernel::StrideD;
  using LayoutD = decltype(cute::make_layout(make_shape(0, 0, 0), StrideD{}));
};

template <typename T>
typename T::Gemm::Arguments args_from_options(
    at::Tensor& D,
    at::Tensor const& A,
    at::Tensor const& B,
    at::Tensor const& A_sf,
    at::Tensor const& B_sf,
    at::Tensor const& alpha,
    int64_t M,
    int64_t N,
    int64_t K) {
  using ElementA = typename T::Gemm::ElementA;
  using ElementB = typename T::Gemm::ElementB;
  using ElementSFA = cutlass::float_ue4m3_t;
  using ElementSFB = cutlass::float_ue4m3_t;
  using ElementD = typename T::Gemm::ElementD;
  using ElementCompute = float;
  using StrideA = typename T::StrideA;
  using StrideB = typename T::StrideB;
  using StrideD = typename T::StrideD;
  using Sm100BlkScaledConfig = typename T::Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;

  int m = static_cast<int>(M);
  int n = static_cast<int>(N);
  int k = static_cast<int>(K);
  auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {m, k, 1});
  auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1});
  auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {m, n, 1});

  auto layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1));
  auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1));

  typename T::Gemm::Arguments arguments{
      cutlass::gemm::GemmUniversalMode::kGemm,
      {m, n, k, 1},
      {// Mainloop arguments
       static_cast<ElementA const*>(A.data_ptr()),
       stride_A,
       static_cast<ElementB const*>(B.data_ptr()),
       stride_B,
       static_cast<ElementSFA const*>(A_sf.data_ptr()),
       layout_SFA,
       static_cast<ElementSFB const*>(B_sf.data_ptr()),
       layout_SFB},
      {// Epilogue arguments
       {},  // epilogue.thread
       static_cast<ElementD const*>(D.data_ptr()),
       stride_D,
       static_cast<ElementD*>(D.data_ptr()),
       stride_D}};
  auto& fusion_args = arguments.epilogue.thread;
  fusion_args.alpha_ptr = static_cast<ElementCompute const*>(alpha.data_ptr());
  return arguments;
}

template <typename T>
void runGemm(
    at::Tensor& D,
    at::Tensor const& A,
    at::Tensor const& B,
    at::Tensor const& A_sf,
    at::Tensor const& B_sf,
    at::Tensor const& alpha,
    int64_t m,
    int64_t n,
    int64_t k,
    cudaStream_t stream) {
  typename Fp4GemmSm100<T>::Gemm gemm;

  auto arguments = args_from_options<Fp4GemmSm100<T>>(D, A, B, A_sf, B_sf, alpha, m, n, k);

  size_t workspace_size = Fp4GemmSm100<T>::Gemm::get_workspace_size(arguments);
  auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
  auto workspace = torch::empty(workspace_size, workspace_options);

  CUTLASS_CHECK(gemm.can_implement(arguments));

  CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream));

  CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
}
#else
template <typename T>
void runGemm(
    at::Tensor& D,
    at::Tensor const& A,
    at::Tensor const& B,
    at::Tensor const& A_sf,
    at::Tensor const& B_sf,
    at::Tensor const& alpha,
    int64_t m,
    int64_t n,
    int64_t k,
    cudaStream_t stream) {
  TORCH_CHECK(
      false,
      "Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
      "a CUTLASS 3.8 source directory to enable support.");
}
#endif  // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)

#define CHECK_TYPE(x, st, m) TORCH_CHECK(x.scalar_type() == st, "Inconsistency of Tensor type:", m)
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
#define CHECK_INPUT(x, st, m) \
  CHECK_TH_CUDA(x, m);        \
  CHECK_CONTIGUOUS(x, m);     \
  CHECK_TYPE(x, st, m)

constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;

void cutlass_scaled_fp4_mm_sm100a(
    torch::Tensor& D,
    torch::Tensor const& A,
    torch::Tensor const& B,
    torch::Tensor const& A_sf,
    torch::Tensor const& B_sf,
    torch::Tensor const& alpha) {
  CHECK_INPUT(A, FLOAT4_E2M1X2, "a");
  CHECK_INPUT(B, FLOAT4_E2M1X2, "b");

  CHECK_INPUT(A_sf, SF_DTYPE, "scale_a");
  CHECK_INPUT(B_sf, SF_DTYPE, "scale_b");

  CHECK_INPUT(alpha, at::ScalarType::Float, "alpha");

  TORCH_CHECK(A.dim() == 2, "a must be a matrix");
  TORCH_CHECK(B.dim() == 2, "b must be a matrix");
  TORCH_CHECK(
      A.sizes()[1] == B.sizes()[1], "a and b shapes cannot be multiplied (", A.sizes()[0], "x", A.sizes()[1], " and ",
      B.sizes()[0], "x", B.sizes()[1], ")");

  auto const m = A.sizes()[0];
  auto const n = B.sizes()[0];
  auto const k = A.sizes()[1] * 2;

  constexpr int alignment = 32;
  TORCH_CHECK(
      k % alignment == 0, "Expected k to be divisible by ", alignment, ", but got a shape: (", A.sizes()[0], "x",
      A.sizes()[1], "), k: ", k, ".");
  TORCH_CHECK(
      n % alignment == 0, "Expected n to be divisible by ", alignment, ", but got b shape: (", B.sizes()[0], "x",
      B.sizes()[1], ").");

  auto round_up = [](int x, int y) { return (x + y - 1) / y * y; };
  int rounded_m = round_up(m, 128);
  int rounded_n = round_up(n, 128);
  // Since k is divisible by 32 (alignment), k / 16 is guaranteed to be an
  // integer.
  int rounded_k = round_up(k / 16, 4);

  TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
  TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
  TORCH_CHECK(
      A_sf.sizes()[1] == B_sf.sizes()[1], "scale_a and scale_b shapes cannot be multiplied (", A_sf.sizes()[0], "x",
      A_sf.sizes()[1], " and ", B_sf.sizes()[0], "x", B_sf.sizes()[1], ")");
  TORCH_CHECK(
      A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k, "scale_a must be padded and swizzled to a shape (",
      rounded_m, "x", rounded_k, "), but got a shape (", A_sf.sizes()[0], "x", A_sf.sizes()[1], ")");
  TORCH_CHECK(
      B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k, "scale_b must be padded and swizzled to a shape (",
      rounded_n, "x", rounded_k, "), but got a shape (", B_sf.sizes()[0], "x", B_sf.sizes()[1], ")");

  auto out_dtype = D.dtype();
  at::cuda::CUDAGuard device_guard{(char)A.get_device()};
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());

  if (out_dtype == at::ScalarType::Half) {
    runGemm<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
  } else if (out_dtype == at::ScalarType::BFloat16) {
    runGemm<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
  } else if (out_dtype == at::ScalarType::Float) {
    runGemm<float>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
  } else {
    TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm");
  }
}
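As a quick illustration of the scale-factor shape checks above (the numbers are chosen purely for illustration), take m = 150, n = 256, k = 128, so the packed A operand is a 150x64 uint8 tensor:

def round_up(x, y):
    return (x + y - 1) // y * y

m, n, k = 150, 256, 128           # k counts unpacked fp4 elements; A is (150, 64) packed uint8
rounded_m = round_up(m, 128)      # 256: scale rows are padded to a multiple of 128
rounded_n = round_up(n, 128)      # 256
rounded_k = round_up(k // 16, 4)  # 8: one float8_e4m3fn scale per 16 elements, padded to a multiple of 4
# A_sf must therefore be (256, 8) and B_sf (256, 8), both float8_e4m3fn and already swizzled.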
sgl-kernel/csrc/torch_extension.cc
@@ -114,6 +114,17 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
       " ScalarType out_dtype, int cublas_handle, int cuda_stream) -> ()");
   m.impl("cublas_grouped_gemm", torch::kCUDA, &cublas_grouped_gemm);

+  m.def(
+      "cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
+      " Tensor block_scale_a, Tensor block_scale_b,"
+      " Tensor alpha) -> ()");
+  m.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
+
+  m.def(
+      "scaled_fp4_quant(Tensor! output, Tensor! input,"
+      " Tensor! output_scale, Tensor! input_scale) -> ()");
+  m.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);
+
   /*
    * From csrc/moe
    */
sgl-kernel/include/sgl_kernel_ops.h
@@ -113,6 +113,13 @@ void apply_rope_pos_ids_cos_sin_cache(
  * From csrc/gemm
  */
 torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch::Tensor qzeros);
+void cutlass_scaled_fp4_mm(
+    torch::Tensor& D,
+    torch::Tensor const& A,
+    torch::Tensor const& B,
+    torch::Tensor const& A_sf,
+    torch::Tensor const& B_sf,
+    torch::Tensor const& alpha);
 torch::Tensor int8_scaled_mm(
     const torch::Tensor& mat_a,
     const torch::Tensor& mat_b,

@@ -133,6 +140,8 @@ torch::Tensor fp8_blockwise_scaled_mm(
     const torch::Tensor& scales_a,
     const torch::Tensor& scales_b,
     const torch::Dtype& out_dtype);
+void scaled_fp4_quant(
+    torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_scale, torch::Tensor const& input_scale);
 void sgl_per_token_group_quant_fp8(
     at::Tensor input,
     at::Tensor output_q,
sgl-kernel/python/sgl_kernel/__init__.py
@@ -26,9 +26,11 @@ from sgl_kernel.gemm import (
     awq_dequantize,
     bmm_fp8,
     cublas_grouped_gemm,
+    cutlass_scaled_fp4_mm,
     fp8_blockwise_scaled_mm,
     fp8_scaled_mm,
     int8_scaled_mm,
+    scaled_fp4_quant,
     sgl_per_tensor_quant_fp8,
     sgl_per_token_group_quant_fp8,
     sgl_per_token_group_quant_int8,
sgl-kernel/python/sgl_kernel/gemm.py
-from typing import List, Optional
+from typing import List, Optional, Tuple

 import torch
 from sgl_kernel.utils import _get_cache_buf, get_cuda_stream

@@ -145,3 +145,73 @@ def sgl_per_token_quant_fp8(
     output_s: torch.Tensor,
 ) -> None:
     torch.ops.sgl_kernel.sgl_per_token_quant_fp8(input, output_q, output_s)
+
+
+def cutlass_scaled_fp4_mm(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    block_scale_a: torch.Tensor,
+    block_scale_b: torch.Tensor,
+    alpha: torch.Tensor,
+    out_dtype: torch.dtype,
+) -> torch.Tensor:
+    assert a.ndim == 2 and b.ndim == 2
+    m, n = a.shape[0], b.shape[0]
+    out = torch.empty((m, n), dtype=out_dtype, device=a.device)
+    torch.ops.sgl_kernels.cutlass_scaled_fp4_mm(out, a, b, block_scale_a, block_scale_b, alpha)
+    return out
+
+
+def scaled_fp4_quant(
+    input: torch.Tensor, input_global_scale: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize input tensor to FP4 and return quantized tensor and scale.
+
+    This function quantizes the last dimension of the given tensor `input`. For
+    every 16 consecutive elements, a single dynamically computed scaling factor
+    is shared. This scaling factor is quantized using the `input_global_scale`
+    and is stored in a swizzled layout (see
+    https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x).
+
+    Args:
+        input: The input tensor to be quantized to FP4.
+        input_global_scale: A scalar scaling factor for the entire tensor.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP4 (every two
+            values packed into one uint8) and the float8_e4m3 scaling factors
+            in a swizzled layout.
+    """
+    assert input.ndim >= 1, f"input.ndim needs to be >= 1, but got {input.ndim}."
+    other_dims = 1 if input.ndim == 1 else -1
+    input = input.reshape(other_dims, input.shape[-1])
+    m, n = input.shape
+    block_size = 16
+    device = input.device
+
+    assert n % block_size == 0, f"last dim has to be multiple of 16, but got {n}."
+    assert input.dtype in (
+        torch.float16,
+        torch.bfloat16,
+    ), f"input.dtype needs to be fp16 or bf16 but got {input.dtype}."
+
+    # Two fp4 values will be packed into an uint8.
+    output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)
+
+    # We use the rounded values to store the swizzled values. Then, the scaling
+    # factors in float8_e4m3fn are packed into an int32 for every 4 values.
+    rounded_m = ((m + 128 - 1) // 128) * 128
+    scale_n = n // block_size
+    rounded_n = ((scale_n + 4 - 1) // 4) * 4
+    output_scale = torch.empty(
+        (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
+    )
+
+    torch.ops.sgl_kernels.scaled_fp4_quant(output, input, output_scale, input_global_scale)
+    output_scale = output_scale.view(torch.float8_e4m3fn)
+    return output, output_scale
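Putting the two wrappers together, a typical call sequence looks like the sketch below. It is assembled from the wrappers above and the shapes used in test_fp4_gemm.py further down; it assumes a compute-capability-10.0 (sm100a) device and a build with ENABLE_NVFP4, and the tensor names are illustrative.

import torch
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant

FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0
FLOAT4_E2M1_MAX = 6.0

a = torch.randn((128, 128), dtype=torch.bfloat16, device="cuda")  # (m, k)
b = torch.randn((256, 128), dtype=torch.bfloat16, device="cuda")  # (n, k), row-major like the tests

a_gs = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a.abs().amax()).to(torch.float32)
b_gs = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b.abs().amax()).to(torch.float32)

a_fp4, a_sf = scaled_fp4_quant(a, a_gs)  # packed uint8 (128, 64) plus swizzled e4m3 scales
b_fp4, b_sf = scaled_fp4_quant(b, b_gs)
alpha = 1.0 / (a_gs * b_gs)              # undone by the GEMM epilogue
out = cutlass_scaled_fp4_mm(a_fp4, b_fp4, a_sf, b_sf, alpha, torch.bfloat16)  # (128, 256)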
sgl-kernel/setup.py
@@ -153,6 +153,10 @@ sources = [
     "csrc/gemm/fp8_gemm_kernel.cu",
     "csrc/gemm/fp8_blockwise_gemm_kernel.cu",
     "csrc/gemm/int8_gemm_kernel.cu",
+    "csrc/gemm/nvfp4_quant_entry.cu",
+    "csrc/gemm/nvfp4_quant_kernels.cu",
+    "csrc/gemm/nvfp4_scaled_mm_entry.cu",
+    "csrc/gemm/nvfp4_scaled_mm_kernels.cu",
     "csrc/gemm/per_token_group_quant_8bit.cu",
     "csrc/gemm/per_token_quant_fp8.cu",
     "csrc/gemm/per_tensor_quant_fp8.cu",

@@ -169,6 +173,7 @@ sources = [
 enable_bf16 = os.getenv("SGL_KERNEL_ENABLE_BF16", "0") == "1"
 enable_fp8 = os.getenv("SGL_KERNEL_ENABLE_FP8", "0") == "1"
+enable_fp4 = os.getenv("SGL_KERNEL_ENABLE_FP4", "0") == "1"
 enable_sm90a = os.getenv("SGL_KERNEL_ENABLE_SM90A", "0") == "1"
 enable_sm100a = os.getenv("SGL_KERNEL_ENABLE_SM100A", "0") == "1"
 cuda_version = _get_cuda_version()

@@ -180,6 +185,7 @@ if torch.cuda.is_available():
     if cuda_version >= (12, 8) and sm_version >= 100:
         nvcc_flags.append("-gencode=arch=compute_100,code=sm_100")
         nvcc_flags.append("-gencode=arch=compute_100a,code=sm_100a")
+        nvcc_flags.append("-DENABLE_NVFP4=1")
     else:
         nvcc_flags.append("-use_fast_math")
     if sm_version >= 90:
@@ -188,12 +194,12 @@ if torch.cuda.is_available():
         nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")
 else:
     # compilation environment without GPU
-    if enable_sm90a:
-        nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
     if enable_sm100a:
         nvcc_flags.append("-gencode=arch=compute_100a,code=sm_100a")
-    else:
-        nvcc_flags.append("-use_fast_math")
+    if enable_sm90a:
+        nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
+    if enable_fp4:
+        nvcc_flags.append("-DENABLE_NVFP4=1")
     if enable_fp8:
         nvcc_flags.extend(nvcc_flags_fp8)
     if enable_bf16:
sgl-kernel/tests/test_fp4_gemm.py (new file, mode 100644)
import pytest
import torch
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant

if torch.cuda.get_device_capability() < (10, 0):
    pytest.skip(
        reason="Nvfp4 Requires compute capability of 10 or above.",
        allow_module_level=True,
    )

DTYPES = [torch.float16, torch.bfloat16]
# m, n, k
SHAPES = [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)]
PAD_SHAPES = [(150, 128, 64), (128, 128, 96)]
SHAPES.extend(PAD_SHAPES)

FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

kE2M1ToFloatArray = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def e2m1_to_fp32(int4_value):
    signBit = int4_value & 0x8
    int4_absValue = int4_value & 0x7
    float_result = kE2M1ToFloatArray[int4_absValue]
    if signBit:
        float_result = -float_result
    return float_result


def break_fp4_bytes(a, dtype):
    assert a.dtype == torch.uint8
    m, n = a.shape
    a = a.flatten()
    # Get upper 4 bits
    highHalfByte = (a & 0xF0) >> 4
    # Get lower 4 bits
    lowHalfByte = a & 0x0F
    fH = torch.tensor([e2m1_to_fp32(x) for x in highHalfByte]).to(a.device)
    fL = torch.tensor([e2m1_to_fp32(x) for x in lowHalfByte]).to(a.device)
    # [0xAB, 0xCD] -> [0xB, 0xA, 0xD, 0xC]
    out = torch.stack((fL, fH), dim=-1).reshape(m, n * 2)
    return out


def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
    sf_m, sf_k = a_sf_swizzled.shape
    m_tiles = (m + 128 - 1) // 128
    f = block_size * 4
    k_tiles = (k + f - 1) // f
    tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4))
    tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
    out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size)
    return out[0:m, 0:k]


def dequantize_to_dtype(tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16):
    """Dequantize the fp4 tensor back to high precision."""
    # Two fp4 values are packed into one uint8.
    assert tensor_fp4.dtype == torch.uint8
    m, packed_k = tensor_fp4.shape
    k = packed_k * 2
    tensor_f32 = break_fp4_bytes(tensor_fp4, dtype)
    tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
    tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
    tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
    tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale

    # scale the tensor
    out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
    return out


def get_ref_results(
    a_fp4,
    b_fp4,
    a_sf,
    b_sf,
    a_global_scale,
    b_global_scale,
    m,
    n,
    dtype,
    block_size,
    device,
):
    _, m_k = a_fp4.shape
    _, n_k = b_fp4.shape
    assert m_k == n_k
    a_in_dtype = dequantize_to_dtype(
        a_fp4, a_sf, a_global_scale, dtype=dtype, device=device, block_size=block_size
    )
    b_in_dtype = dequantize_to_dtype(
        b_fp4, b_sf, b_global_scale, dtype=dtype, device=device, block_size=block_size
    )
    return torch.matmul(a_in_dtype, b_in_dtype.t())


@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("shape", SHAPES)
@torch.inference_mode()
def test_nvfp4_gemm(
    dtype: torch.dtype,
    shape: tuple[int, int],
) -> None:
    m, n, packed_k = shape
    k = packed_k * 2
    block_size = 16
    a_dtype = torch.randn((m, k), dtype=dtype, device="cuda")
    b_dtype = torch.randn((n, k), dtype=dtype, device="cuda")

    a_global_scale = (
        (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a_dtype.flatten(), dim=-1)
    ).to(torch.float32)
    b_global_scale = (
        (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1)
    ).to(torch.float32)
    alpha = 1.0 / (a_global_scale * b_global_scale)
    a_fp4, a_scale_interleaved = scaled_fp4_quant(a_dtype, a_global_scale)
    b_fp4, b_scale_interleaved = scaled_fp4_quant(b_dtype, b_global_scale)

    expected_out = get_ref_results(
        a_fp4,
        b_fp4,
        a_scale_interleaved,
        b_scale_interleaved,
        a_global_scale,
        b_global_scale,
        m,
        n,
        dtype,
        block_size,
        "cuda",
    )
    out = cutlass_scaled_fp4_mm(
        a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype
    )

    torch.testing.assert_close(out, expected_out.to(dtype=dtype), atol=1e-1, rtol=1e-1)
sgl-kernel/tests/test_fp4_quantize.py (new file, mode 100644)
import pytest
import torch
from sgl_kernel import scaled_fp4_quant

if torch.cuda.get_device_capability() < (10, 0):
    pytest.skip(
        reason="Nvfp4 Requires compute capability of 10 or above.",
        allow_module_level=True,
    )

DTYPES = [torch.float16, torch.bfloat16]
SHAPES = [(128, 64), (128, 128), (256, 64), (256, 128)]
PAD_SHAPES = [
    (90, 64), (150, 64), (128, 48), (128, 80), (150, 80),
    (90, 48), (90, 128), (150, 128), (150, 48), (90, 80),
]

FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

# E2M1 to float
# 0111 -> 6
# 0110 -> 4
# 0101 -> 3
# 0100 -> 2
# 0011 -> 1.5
# 0010 -> 1
# 0001 -> 0.5
# 0000 -> 0
E2M1_TO_FLOAT32 = [
    0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
    0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
]
BLOCK_SIZE = 16


def cast_from_fp4(x, m, n):
    # The fp4 values are packed in uint8 as [v_1st | v_2nd]
    v_2nd = x & 0xF
    v_1st = (x >> 4) & 0xF
    c = torch.stack((v_2nd, v_1st), dim=-1)
    out = torch.tensor([E2M1_TO_FLOAT32[x] for x in c.flatten()])
    out = out.reshape(m, n).to(torch.float32)
    return out


def cast_to_fp4(x):
    sign = torch.sign(x)
    x = torch.abs(x)
    x[(x >= 0.0) & (x <= 0.25)] = 0.0
    x[(x > 0.25) & (x < 0.75)] = 0.5
    x[(x >= 0.75) & (x <= 1.25)] = 1.0
    x[(x > 1.25) & (x < 1.75)] = 1.5
    x[(x >= 1.75) & (x <= 2.5)] = 2.0
    x[(x > 2.5) & (x < 3.5)] = 3.0
    x[(x >= 3.5) & (x <= 5.0)] = 4.0
    x[x > 5.0] = 6.0
    return x * sign


def get_reciprocal(x):
    if isinstance(x, torch.Tensor):
        return torch.where(x == 0, torch.tensor(0.0, dtype=x.dtype), 1.0 / x)
    elif isinstance(x, (float, int)):
        return 0.0 if x == 0 else 1.0 / x
    else:
        raise TypeError("Input must be a float, int, or a torch.Tensor.")


def ref_nvfp4_quant(x, global_scale):
    assert global_scale.dtype == torch.float32
    assert x.ndim == 2
    m, n = x.shape
    x = torch.reshape(x, (m, n // BLOCK_SIZE, BLOCK_SIZE))
    vec_max = torch.max(torch.abs(x), dim=-1, keepdim=True)[0].to(torch.float32)
    scale = global_scale * (vec_max * get_reciprocal(FLOAT4_E2M1_MAX))
    scale = scale.to(torch.float8_e4m3fn).to(torch.float32)
    output_scale = get_reciprocal(scale * get_reciprocal(global_scale))

    scaled_x = x.to(torch.float32) * output_scale
    clipped_x = torch.clamp(scaled_x, -6.0, 6.0).reshape(m, n)
    return cast_to_fp4(clipped_x), scale.squeeze(-1)


def recover_swizzled_scales(scale, m, n):
    rounded_m = ((m + 128 - 1) // 128) * 128
    scale_n = n // BLOCK_SIZE
    rounded_n = ((scale_n + 4 - 1) // 4) * 4
    # Recover the swizzled scaling factor to linear layout
    tmp = torch.reshape(scale, (1, rounded_m // 128, rounded_n // 4, 32, 4, 4))
    tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
    result = torch.reshape(tmp, (rounded_m, rounded_n)).to(torch.float32)
    return result[:m, :scale_n]


@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("shape", SHAPES)
@torch.inference_mode()
def test_quantize_to_fp4(
    dtype: torch.dtype,
    shape: tuple[int, int],
) -> None:
    torch.manual_seed(42)
    torch.set_default_device("cuda:0")

    m, n = shape

    x = torch.randn((m, n), dtype=dtype)
    tensor_amax = torch.abs(x).max().to(torch.float32)
    global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
    out_ref, scale_ref = ref_nvfp4_quant(x, global_scale)
    out, out_scale = scaled_fp4_quant(x, global_scale)
    scale_ans = recover_swizzled_scales(out_scale, m, n)
    out_ans = cast_from_fp4(out, m, n)
    torch.testing.assert_close(out_ans, out_ref)
    torch.testing.assert_close(scale_ans, scale_ref)


@pytest.mark.parametrize("pad_shape", PAD_SHAPES)
@torch.inference_mode()
def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None:
    torch.manual_seed(42)
    dtype = torch.float16
    torch.set_default_device("cuda:0")

    m, n = pad_shape

    x = torch.randn((m, n), dtype=dtype)
    tensor_amax = torch.abs(x).max().to(torch.float32)
    global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
    out_ref, scale_ref = ref_nvfp4_quant(x, global_scale)
    out, out_scale = scaled_fp4_quant(x, global_scale)
    scale_ans = recover_swizzled_scales(out_scale, m, n)
    out_ans = cast_from_fp4(out, m, n)
    torch.testing.assert_close(out_ans, out_ref)
    torch.testing.assert_close(scale_ans, scale_ref)