Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
zhaoyu6
sglang
Commits
1e85589d
Unverified
Commit
1e85589d
authored
Aug 29, 2025
by
hlu1
Committed by
GitHub
Aug 29, 2025
Browse files
Make fp4_quantize kernels work on sm103 (#9807)
Signed-off-by:
Hao Lu
<
14827759+hlu1@users.noreply.github.com
>
parent
c2a26e72
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
189 additions
and
325 deletions
+189
-325
sgl-kernel/csrc/gemm/nvfp4_expert_quant.cu
sgl-kernel/csrc/gemm/nvfp4_expert_quant.cu
+9
-161
sgl-kernel/csrc/gemm/nvfp4_quant.cuh
sgl-kernel/csrc/gemm/nvfp4_quant.cuh
+176
-0
sgl-kernel/csrc/gemm/nvfp4_quant_kernels.cu
sgl-kernel/csrc/gemm/nvfp4_quant_kernels.cu
+4
-164
No files found.
sgl-kernel/csrc/gemm/nvfp4_expert_quant.cu
View file @
1e85589d
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <torch/all.h>
#include <torch/all.h>
template
<
typename
T
>
#include "nvfp4_quant.cuh"
struct
TypeConverter
{
#include "utils.h"
using
Type
=
half2
;
};
// keep for generality
template
<
>
struct
TypeConverter
<
half2
>
{
using
Type
=
half
;
};
template
<
>
struct
TypeConverter
<
half
>
{
using
Type
=
half2
;
};
template
<
>
struct
TypeConverter
<
__nv_bfloat162
>
{
using
Type
=
__nv_bfloat16
;
};
template
<
>
struct
TypeConverter
<
__nv_bfloat16
>
{
using
Type
=
__nv_bfloat162
;
};
#define ELTS_PER_THREAD 8
constexpr
int
CVT_FP4_ELTS_PER_THREAD
=
8
;
constexpr
int
CVT_FP4_SF_VEC_SIZE
=
16
;
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
inline
__device__
uint32_t
fp32_vec_to_e2m1
(
float
(
&
array
)[
8
])
{
// PTX instructions used here requires sm100a.
#if CUDA_VERSION >= 12080
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) && __CUDA_ARCH_HAS_FEATURE__(SM100_ALL)
uint32_t
val
;
asm
volatile
(
"{
\n
"
".reg .b8 byte0;
\n
"
".reg .b8 byte1;
\n
"
".reg .b8 byte2;
\n
"
".reg .b8 byte3;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;
\n
"
"mov.b32 %0, {byte0, byte1, byte2, byte3};
\n
"
"}"
:
"=r"
(
val
)
:
"f"
(
array
[
0
]),
"f"
(
array
[
1
]),
"f"
(
array
[
2
]),
"f"
(
array
[
3
]),
"f"
(
array
[
4
]),
"f"
(
array
[
5
]),
"f"
(
array
[
6
]),
"f"
(
array
[
7
]));
return
val
;
#else
return
0
;
#endif
#endif
}
// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
inline
__device__
uint32_t
fp32_vec_to_e2m1
(
float2
(
&
array
)[
4
])
{
// PTX instructions used here requires sm100a.
#if CUDA_VERSION >= 12080
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) && __CUDA_ARCH_HAS_FEATURE__(SM100_ALL)
uint32_t
val
;
asm
volatile
(
"{
\n
"
".reg .b8 byte0;
\n
"
".reg .b8 byte1;
\n
"
".reg .b8 byte2;
\n
"
".reg .b8 byte3;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;
\n
"
"mov.b32 %0, {byte0, byte1, byte2, byte3};
\n
"
"}"
:
"=r"
(
val
)
:
"f"
(
array
[
0
].
x
),
"f"
(
array
[
0
].
y
),
"f"
(
array
[
1
].
x
),
"f"
(
array
[
1
].
y
),
"f"
(
array
[
2
].
x
),
"f"
(
array
[
2
].
y
),
"f"
(
array
[
3
].
x
),
"f"
(
array
[
3
].
y
));
return
val
;
#else
return
0
;
#endif
#endif
}
// Fast reciprocal.
inline
__device__
float
reciprocal_approximate_ftz
(
float
a
)
{
float
b
;
asm
volatile
(
"rcp.approx.ftz.f32 %0, %1;
\n
"
:
"=f"
(
b
)
:
"f"
(
a
));
return
b
;
}
template
<
class
SFType
,
int
CVT_FP4_NUM_THREADS_PER_SF
>
__device__
uint8_t
*
cvt_quant_to_fp4_get_sf_out_offset
(
int
rowIdx
,
int
colIdx
,
int
numCols
,
SFType
*
SFout
)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
static_assert
(
CVT_FP4_NUM_THREADS_PER_SF
==
1
||
CVT_FP4_NUM_THREADS_PER_SF
==
2
);
// One pair of threads write one SF to global memory.
// TODO: stage through smem for packed STG.32
// is it better than STG.8 from 4 threads ?
if
(
threadIdx
.
x
%
CVT_FP4_NUM_THREADS_PER_SF
==
0
)
{
// SF vector index (16 elements share one SF in the K dimension).
int32_t
kIdx
=
colIdx
/
CVT_FP4_NUM_THREADS_PER_SF
;
int32_t
mIdx
=
rowIdx
;
// SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
// --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
int32_t
mTileIdx
=
mIdx
/
(
32
*
4
);
// SF vector size 16.
int
factor
=
CVT_FP4_SF_VEC_SIZE
*
4
;
int32_t
numKTiles
=
(
numCols
+
factor
-
1
)
/
factor
;
int64_t
mTileStride
=
numKTiles
*
32
*
4
*
4
;
int32_t
kTileIdx
=
(
kIdx
/
4
);
int64_t
kTileStride
=
32
*
4
*
4
;
// M tile layout [32, 4] is column-major.
int32_t
outerMIdx
=
(
mIdx
%
32
);
int64_t
outerMStride
=
4
*
4
;
int32_t
innerMIdx
=
(
mIdx
%
(
32
*
4
))
/
32
;
int64_t
innerMStride
=
4
;
int32_t
innerKIdx
=
(
kIdx
%
4
);
int64_t
innerKStride
=
1
;
// Compute the global offset.
int64_t
SFOffset
=
mTileIdx
*
mTileStride
+
kTileIdx
*
kTileStride
+
outerMIdx
*
outerMStride
+
innerMIdx
*
innerMStride
+
innerKIdx
*
innerKStride
;
return
reinterpret_cast
<
uint8_t
*>
(
SFout
)
+
SFOffset
;
}
#endif
return
nullptr
;
}
// Define a 16 bytes packed data type.
template
<
class
Type
>
struct
PackedVec
{
typename
TypeConverter
<
Type
>::
Type
elts
[
4
];
};
template
<
>
struct
PackedVec
<
__nv_fp8_e4m3
>
{
__nv_fp8x2_e4m3
elts
[
8
];
};
// Quantizes the provided PackedVec into the uint32_t output
// Quantizes the provided PackedVec into the uint32_t output
template
<
class
Type
,
bool
UE8M0_SF
=
false
>
template
<
class
Type
,
bool
UE8M0_SF
=
false
>
...
@@ -720,6 +562,9 @@ void scaled_fp4_experts_quant_sm100a(
...
@@ -720,6 +562,9 @@ void scaled_fp4_experts_quant_sm100a(
torch
::
Tensor
const
&
input_global_scale
,
torch
::
Tensor
const
&
input_global_scale
,
torch
::
Tensor
const
&
input_offset_by_experts
,
torch
::
Tensor
const
&
input_offset_by_experts
,
torch
::
Tensor
const
&
output_scale_offset_by_experts
)
{
torch
::
Tensor
const
&
output_scale_offset_by_experts
)
{
auto
sm_version
=
getSMVersion
();
TORCH_CHECK
(
sm_version
==
100
||
sm_version
==
103
,
"fp4_quant is only supported on sm100a/sm103a"
);
CHECK_INPUT
(
output
,
"output must be a CUDA tensor"
);
CHECK_INPUT
(
output
,
"output must be a CUDA tensor"
);
CHECK_INPUT
(
output_scale
,
"output_scale must be a CUDA tensor"
);
CHECK_INPUT
(
output_scale
,
"output_scale must be a CUDA tensor"
);
CHECK_INPUT
(
input
,
"input must be a CUDA tensor"
);
CHECK_INPUT
(
input
,
"input must be a CUDA tensor"
);
...
@@ -801,6 +646,9 @@ void silu_and_mul_scaled_fp4_experts_quant_sm100a(
...
@@ -801,6 +646,9 @@ void silu_and_mul_scaled_fp4_experts_quant_sm100a(
torch
::
Tensor
const
&
input_global_scale
,
torch
::
Tensor
const
&
input_global_scale
,
torch
::
Tensor
const
&
mask
,
torch
::
Tensor
const
&
mask
,
bool
use_silu_and_mul
)
{
bool
use_silu_and_mul
)
{
auto
sm_version
=
getSMVersion
();
TORCH_CHECK
(
sm_version
==
100
||
sm_version
==
103
,
"fp4_quant is only supported on sm100a/sm103a"
);
CHECK_INPUT
(
output
,
"output must be a CUDA tensor"
);
CHECK_INPUT
(
output
,
"output must be a CUDA tensor"
);
CHECK_INPUT
(
output_scale
,
"output_scale must be a CUDA tensor"
);
CHECK_INPUT
(
output_scale
,
"output_scale must be a CUDA tensor"
);
CHECK_INPUT
(
input
,
"input must be a CUDA tensor"
);
CHECK_INPUT
(
input
,
"input must be a CUDA tensor"
);
...
...
sgl-kernel/csrc/gemm/nvfp4_quant.cuh
0 → 100644
View file @
1e85589d
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cuda.h>
#include <cuda_fp8.h>
#include <cutlass/arch/config.h>
// Get type2 from type or vice versa (applied to half and bfloat16)
template
<
typename
T
>
struct
TypeConverter
{
using
Type
=
half2
;
};
// keep for generality
template
<
>
struct
TypeConverter
<
half2
>
{
using
Type
=
half
;
};
template
<
>
struct
TypeConverter
<
half
>
{
using
Type
=
half2
;
};
template
<
>
struct
TypeConverter
<
__nv_bfloat162
>
{
using
Type
=
__nv_bfloat16
;
};
template
<
>
struct
TypeConverter
<
__nv_bfloat16
>
{
using
Type
=
__nv_bfloat162
;
};
#define ELTS_PER_THREAD 8
constexpr
int
CVT_FP4_ELTS_PER_THREAD
=
8
;
constexpr
int
CVT_FP4_SF_VEC_SIZE
=
16
;
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
inline
__device__
uint32_t
fp32_vec_to_e2m1
(
float
(
&
array
)[
8
])
{
// PTX instructions used here requires sm100a/sm103a.
#if CUTLASS_ARCH_MMA_SM100A_ENABLED || CUTLASS_ARCH_MMA_SM103A_ENABLED
uint32_t
val
;
asm
volatile
(
"{
\n
"
".reg .b8 byte0;
\n
"
".reg .b8 byte1;
\n
"
".reg .b8 byte2;
\n
"
".reg .b8 byte3;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;
\n
"
"mov.b32 %0, {byte0, byte1, byte2, byte3};
\n
"
"}"
:
"=r"
(
val
)
:
"f"
(
array
[
0
]),
"f"
(
array
[
1
]),
"f"
(
array
[
2
]),
"f"
(
array
[
3
]),
"f"
(
array
[
4
]),
"f"
(
array
[
5
]),
"f"
(
array
[
6
]),
"f"
(
array
[
7
]));
return
val
;
#else
return
0
;
#endif
}
// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
inline
__device__
uint32_t
fp32_vec_to_e2m1
(
float2
(
&
array
)[
4
])
{
// PTX instructions used here requires sm100a/sm103a.
#if CUTLASS_ARCH_MMA_SM100A_ENABLED || CUTLASS_ARCH_MMA_SM103A_ENABLED
uint32_t
val
;
asm
volatile
(
"{
\n
"
".reg .b8 byte0;
\n
"
".reg .b8 byte1;
\n
"
".reg .b8 byte2;
\n
"
".reg .b8 byte3;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;
\n
"
"mov.b32 %0, {byte0, byte1, byte2, byte3};
\n
"
"}"
:
"=r"
(
val
)
:
"f"
(
array
[
0
].
x
),
"f"
(
array
[
0
].
y
),
"f"
(
array
[
1
].
x
),
"f"
(
array
[
1
].
y
),
"f"
(
array
[
2
].
x
),
"f"
(
array
[
2
].
y
),
"f"
(
array
[
3
].
x
),
"f"
(
array
[
3
].
y
));
return
val
;
#else
return
0
;
#endif
}
// Fast reciprocal.
inline
__device__
float
reciprocal_approximate_ftz
(
float
a
)
{
float
b
;
asm
volatile
(
"rcp.approx.ftz.f32 %0, %1;
\n
"
:
"=f"
(
b
)
:
"f"
(
a
));
return
b
;
}
template
<
class
SFType
,
int
CVT_FP4_NUM_THREADS_PER_SF
>
__device__
uint8_t
*
cvt_quant_to_fp4_get_sf_out_offset
(
int
rowIdx
,
int
colIdx
,
int
numCols
,
SFType
*
SFout
)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
static_assert
(
CVT_FP4_NUM_THREADS_PER_SF
==
1
||
CVT_FP4_NUM_THREADS_PER_SF
==
2
);
// One pair of threads write one SF to global memory.
// TODO: stage through smem for packed STG.32
// is it better than STG.8 from 4 threads ?
if
(
threadIdx
.
x
%
CVT_FP4_NUM_THREADS_PER_SF
==
0
)
{
// SF vector index (16 elements share one SF in the K dimension).
int32_t
kIdx
=
colIdx
/
CVT_FP4_NUM_THREADS_PER_SF
;
int32_t
mIdx
=
rowIdx
;
// SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
// --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
int32_t
mTileIdx
=
mIdx
/
(
32
*
4
);
// SF vector size 16.
int
factor
=
CVT_FP4_SF_VEC_SIZE
*
4
;
int32_t
numKTiles
=
(
numCols
+
factor
-
1
)
/
factor
;
int64_t
mTileStride
=
numKTiles
*
32
*
4
*
4
;
int32_t
kTileIdx
=
(
kIdx
/
4
);
int64_t
kTileStride
=
32
*
4
*
4
;
// M tile layout [32, 4] is column-major.
int32_t
outerMIdx
=
(
mIdx
%
32
);
int64_t
outerMStride
=
4
*
4
;
int32_t
innerMIdx
=
(
mIdx
%
(
32
*
4
))
/
32
;
int64_t
innerMStride
=
4
;
int32_t
innerKIdx
=
(
kIdx
%
4
);
int64_t
innerKStride
=
1
;
// Compute the global offset.
int64_t
SFOffset
=
mTileIdx
*
mTileStride
+
kTileIdx
*
kTileStride
+
outerMIdx
*
outerMStride
+
innerMIdx
*
innerMStride
+
innerKIdx
*
innerKStride
;
return
reinterpret_cast
<
uint8_t
*>
(
SFout
)
+
SFOffset
;
}
#endif
return
nullptr
;
}
// Define a 16 bytes packed data type.
template
<
class
Type
>
struct
PackedVec
{
typename
TypeConverter
<
Type
>::
Type
elts
[
4
];
};
template
<
>
struct
PackedVec
<
__nv_fp8_e4m3
>
{
__nv_fp8x2_e4m3
elts
[
8
];
};
sgl-kernel/csrc/gemm/nvfp4_quant_kernels.cu
View file @
1e85589d
...
@@ -15,176 +15,13 @@ limitations under the License.
...
@@ -15,176 +15,13 @@ limitations under the License.
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <cuda_runtime_api.h>
#include <torch/all.h>
#include <torch/all.h>
#include "nvfp4_quant.cuh"
#include "utils.h"
#include "utils.h"
// Get type2 from type or vice versa (applied to half and bfloat16)
template
<
typename
T
>
struct
TypeConverter
{
using
Type
=
half2
;
};
// keep for generality
template
<
>
struct
TypeConverter
<
half2
>
{
using
Type
=
half
;
};
template
<
>
struct
TypeConverter
<
half
>
{
using
Type
=
half2
;
};
template
<
>
struct
TypeConverter
<
__nv_bfloat162
>
{
using
Type
=
__nv_bfloat16
;
};
template
<
>
struct
TypeConverter
<
__nv_bfloat16
>
{
using
Type
=
__nv_bfloat162
;
};
#define ELTS_PER_THREAD 8
constexpr
int
CVT_FP4_ELTS_PER_THREAD
=
8
;
constexpr
int
CVT_FP4_SF_VEC_SIZE
=
16
;
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
inline
__device__
uint32_t
fp32_vec_to_e2m1
(
float
(
&
array
)[
8
])
{
// PTX instructions used here requires sm100a.
#if CUDA_VERSION >= 12080
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) && __CUDA_ARCH_HAS_FEATURE__(SM100_ALL)
uint32_t
val
;
asm
volatile
(
"{
\n
"
".reg .b8 byte0;
\n
"
".reg .b8 byte1;
\n
"
".reg .b8 byte2;
\n
"
".reg .b8 byte3;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;
\n
"
"mov.b32 %0, {byte0, byte1, byte2, byte3};
\n
"
"}"
:
"=r"
(
val
)
:
"f"
(
array
[
0
]),
"f"
(
array
[
1
]),
"f"
(
array
[
2
]),
"f"
(
array
[
3
]),
"f"
(
array
[
4
]),
"f"
(
array
[
5
]),
"f"
(
array
[
6
]),
"f"
(
array
[
7
]));
return
val
;
#else
return
0
;
#endif
#endif
}
// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
inline
__device__
uint32_t
fp32_vec_to_e2m1
(
float2
(
&
array
)[
4
])
{
// PTX instructions used here requires sm100a.
#if CUDA_VERSION >= 12080
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) && __CUDA_ARCH_HAS_FEATURE__(SM100_ALL)
uint32_t
val
;
asm
volatile
(
"{
\n
"
".reg .b8 byte0;
\n
"
".reg .b8 byte1;
\n
"
".reg .b8 byte2;
\n
"
".reg .b8 byte3;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;
\n
"
"mov.b32 %0, {byte0, byte1, byte2, byte3};
\n
"
"}"
:
"=r"
(
val
)
:
"f"
(
array
[
0
].
x
),
"f"
(
array
[
0
].
y
),
"f"
(
array
[
1
].
x
),
"f"
(
array
[
1
].
y
),
"f"
(
array
[
2
].
x
),
"f"
(
array
[
2
].
y
),
"f"
(
array
[
3
].
x
),
"f"
(
array
[
3
].
y
));
return
val
;
#else
return
0
;
#endif
#endif
}
// Fast reciprocal.
inline
__device__
float
reciprocal_approximate_ftz
(
float
a
)
{
float
b
;
asm
volatile
(
"rcp.approx.ftz.f32 %0, %1;
\n
"
:
"=f"
(
b
)
:
"f"
(
a
));
return
b
;
}
template
<
class
SFType
,
int
CVT_FP4_NUM_THREADS_PER_SF
>
__device__
uint8_t
*
cvt_quant_to_fp4_get_sf_out_offset
(
int
rowIdx
,
int
colIdx
,
int
numCols
,
SFType
*
SFout
)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
static_assert
(
CVT_FP4_NUM_THREADS_PER_SF
==
1
||
CVT_FP4_NUM_THREADS_PER_SF
==
2
);
// One pair of threads write one SF to global memory.
// TODO: stage through smem for packed STG.32
// is it better than STG.8 from 4 threads ?
if
(
threadIdx
.
x
%
CVT_FP4_NUM_THREADS_PER_SF
==
0
)
{
// SF vector index (16 elements share one SF in the K dimension).
int32_t
kIdx
=
colIdx
/
CVT_FP4_NUM_THREADS_PER_SF
;
int32_t
mIdx
=
rowIdx
;
// SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
// --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
int32_t
mTileIdx
=
mIdx
/
(
32
*
4
);
// SF vector size 16.
int
factor
=
CVT_FP4_SF_VEC_SIZE
*
4
;
int32_t
numKTiles
=
(
numCols
+
factor
-
1
)
/
factor
;
int64_t
mTileStride
=
numKTiles
*
32
*
4
*
4
;
int32_t
kTileIdx
=
(
kIdx
/
4
);
int64_t
kTileStride
=
32
*
4
*
4
;
// M tile layout [32, 4] is column-major.
int32_t
outerMIdx
=
(
mIdx
%
32
);
int64_t
outerMStride
=
4
*
4
;
int32_t
innerMIdx
=
(
mIdx
%
(
32
*
4
))
/
32
;
int64_t
innerMStride
=
4
;
int32_t
innerKIdx
=
(
kIdx
%
4
);
int64_t
innerKStride
=
1
;
// Compute the global offset.
int64_t
SFOffset
=
mTileIdx
*
mTileStride
+
kTileIdx
*
kTileStride
+
outerMIdx
*
outerMStride
+
innerMIdx
*
innerMStride
+
innerKIdx
*
innerKStride
;
return
reinterpret_cast
<
uint8_t
*>
(
SFout
)
+
SFOffset
;
}
#endif
return
nullptr
;
}
// Define a 16 bytes packed data type.
template
<
class
Type
>
struct
PackedVec
{
typename
TypeConverter
<
Type
>::
Type
elts
[
4
];
};
template
<
>
struct
PackedVec
<
__nv_fp8_e4m3
>
{
__nv_fp8x2_e4m3
elts
[
8
];
};
// Quantizes the provided PackedVec into the uint32_t output
// Quantizes the provided PackedVec into the uint32_t output
template
<
class
Type
,
bool
UE8M0_SF
=
false
>
template
<
class
Type
,
bool
UE8M0_SF
=
false
>
__device__
uint32_t
cvt_warp_fp16_to_fp4
(
PackedVec
<
Type
>&
vec
,
float
SFScaleVal
,
uint8_t
*
SFout
)
{
__device__
uint32_t
cvt_warp_fp16_to_fp4
(
PackedVec
<
Type
>&
vec
,
float
SFScaleVal
,
uint8_t
*
SFout
)
{
...
@@ -364,6 +201,9 @@ inline int getMultiProcessorCount() {
...
@@ -364,6 +201,9 @@ inline int getMultiProcessorCount() {
void
scaled_fp4_quant_sm100a
(
void
scaled_fp4_quant_sm100a
(
torch
::
Tensor
&
output
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
&
output_sf
,
torch
::
Tensor
const
&
input_sf
)
{
torch
::
Tensor
&
output
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
&
output_sf
,
torch
::
Tensor
const
&
input_sf
)
{
auto
sm_version
=
getSMVersion
();
TORCH_CHECK
(
sm_version
==
100
||
sm_version
==
103
,
"fp4_quant is only supported on sm100a/sm103a"
);
int32_t
m
=
input
.
size
(
0
);
int32_t
m
=
input
.
size
(
0
);
int32_t
n
=
input
.
size
(
1
);
int32_t
n
=
input
.
size
(
1
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment