Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ad385667
Commit
ad385667
authored
Oct 23, 2024
by
zhuwenwen
Browse files
Merge branch 'v0.6.3.post1-dev'
parents
be0967c1
903593d3
Changes
364
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
5093 additions
and
14 deletions
+5093
-14
csrc/cutlass_extensions/vllm_collective_builder.cuh
csrc/cutlass_extensions/vllm_collective_builder.cuh
+43
-0
csrc/cutlass_extensions/vllm_custom_types.cuh
csrc/cutlass_extensions/vllm_custom_types.cuh
+50
-0
csrc/cutlass_extensions/vllm_cutlass_library_extension.py
csrc/cutlass_extensions/vllm_cutlass_library_extension.py
+49
-0
csrc/cutlass_extensions/vllm_numeric_conversion.cuh
csrc/cutlass_extensions/vllm_numeric_conversion.cuh
+795
-0
csrc/layernorm_kernels.cu
csrc/layernorm_kernels.cu
+19
-14
csrc/mamba/causal_conv1d/causal_conv1d.cu
csrc/mamba/causal_conv1d/causal_conv1d.cu
+632
-0
csrc/mamba/causal_conv1d/causal_conv1d.h
csrc/mamba/causal_conv1d/causal_conv1d.h
+159
-0
csrc/mamba/causal_conv1d/static_switch.h
csrc/mamba/causal_conv1d/static_switch.h
+28
-0
csrc/mamba/mamba_ssm/selective_scan.h
csrc/mamba/mamba_ssm/selective_scan.h
+266
-0
csrc/mamba/mamba_ssm/selective_scan_fwd.cu
csrc/mamba/mamba_ssm/selective_scan_fwd.cu
+658
-0
csrc/mamba/mamba_ssm/static_switch.h
csrc/mamba/mamba_ssm/static_switch.h
+28
-0
csrc/moe/marlin_kernels/marlin_moe_kernel.h
csrc/moe/marlin_kernels/marlin_moe_kernel.h
+1616
-0
csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu
csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu
+31
-0
csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h
csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h
+20
-0
csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu
csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu
+31
-0
csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h
csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h
+20
-0
csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu
csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu
+31
-0
csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h
csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h
+18
-0
csrc/moe/marlin_moe_ops.cu
csrc/moe/marlin_moe_ops.cu
+587
-0
csrc/moe/torch_bindings.cpp
csrc/moe/torch_bindings.cpp
+12
-0
No files found.
Too many changes to show.
To preserve performance only
364 of 364+
files are displayed.
Plain diff
Email patch
csrc/cutlass_extensions/vllm_collective_builder.cuh
0 → 100644
View file @
ad385667
#pragma once
#include "cutlass/gemm/collective/collective_builder.hpp"
namespace
cutlass
::
gemm
::
collective
{
using
namespace
cute
;
//
// VLLMCollectiveBuilder is a wrapper around CollectiveBuilder that allows for
// for custom kernel tags, allowing you to build custom collectives. Without
// touching the cutlass library headers, using `CutlassKernelTag` will mean it
// will resort to using the standard cutlass collective builder.
//
// Use the default Cutlass collective builder, i.e. use an unmodified cutless
// collective
struct
CutlassKernelTag
{};
template
<
class
KernelTag
,
class
ArchTag
,
class
OpClass
,
class
ElementA
,
class
GmemLayoutA
,
int
AlignmentA
,
class
ElementB
,
class
GmemLayoutB
,
int
AlignmentB
,
class
ElementAccumulator
,
class
TileShape_MNK
,
class
ClusterShape_MNK
,
class
StageCountType
,
class
KernelScheduleType
,
class
Enable
=
void
>
struct
VLLMCollectiveBuilder
{
static_assert
(
sizeof
(
ElementA
)
==
0
,
"Could not build a collective for given parameters."
);
};
template
<
class
ArchTag
,
class
OpClass
,
class
ElementA
,
class
GmemLayoutA
,
int
AlignmentA
,
class
ElementB
,
class
GmemLayoutB
,
int
AlignmentB
,
class
ElementAccumulator
,
class
TileShape_MNK
,
class
ClusterShape_MNK
,
class
StageCountType
,
class
KernelScheduleType
>
struct
VLLMCollectiveBuilder
<
CutlassKernelTag
,
ArchTag
,
OpClass
,
ElementA
,
GmemLayoutA
,
AlignmentA
,
ElementB
,
GmemLayoutB
,
AlignmentB
,
ElementAccumulator
,
TileShape_MNK
,
ClusterShape_MNK
,
StageCountType
,
KernelScheduleType
>
{
using
CollectiveOp
=
typename
CollectiveBuilder
<
ArchTag
,
OpClass
,
ElementA
,
GmemLayoutA
,
AlignmentA
,
ElementB
,
GmemLayoutB
,
AlignmentB
,
ElementAccumulator
,
TileShape_MNK
,
ClusterShape_MNK
,
StageCountType
,
KernelScheduleType
>::
CollectiveOp
;
};
};
// namespace cutlass::gemm::collective
\ No newline at end of file
csrc/cutlass_extensions/vllm_custom_types.cuh
0 → 100644
View file @
ad385667
#pragma once
#include "cutlass/integer_subbyte.h"
namespace
cutlass
{
///////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
Bits
,
int
Bias
,
bool
Signed
=
false
>
struct
vllm_biased_integer_subbyte
:
public
integer_subbyte
<
Bits
,
Signed
>
{
using
Base
=
integer_subbyte
<
Bits
,
Signed
>
;
using
Storage
=
typename
Base
::
Storage
;
using
xint_t
=
typename
Base
::
xint_t
;
using
Base
::
bits_mask_
;
using
Base
::
sign_mask_
;
using
Base
::
storage
;
//
// Methods
//
/// No operation
vllm_biased_integer_subbyte
()
=
default
;
/// Conversion from integer type
CUTLASS_HOST_DEVICE
explicit
vllm_biased_integer_subbyte
(
int
value
)
:
Base
(
value
)
{}
CUTLASS_HOST_DEVICE
explicit
vllm_biased_integer_subbyte
(
unsigned
value
)
:
Base
(
value
)
{}
CUTLASS_HOST_DEVICE
explicit
vllm_biased_integer_subbyte
(
double
value
)
:
Base
(
value
)
{}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
// "GPTQ" types, i.e. symmetric quantization
using
vllm_uint4b8_t
=
vllm_biased_integer_subbyte
<
4
,
8
>
;
// u4b8
using
vllm_uint8b128_t
=
vllm_biased_integer_subbyte
<
8
,
128
>
;
// u8b128
///////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
Bits
,
int
Bias
,
bool
Signed
>
struct
sizeof_bits
<
vllm_biased_integer_subbyte
<
Bits
,
Bias
,
Signed
>>
{
static
constexpr
int
value
=
Bits
;
};
///////////////////////////////////////////////////////////////////////////////////////////////////
}
// namespace cutlass
csrc/cutlass_extensions/vllm_cutlass_library_extension.py
0 → 100644
View file @
ad385667
import
enum
from
typing
import
Dict
,
Union
from
cutlass_library
import
*
#
# Extend cutlass library with custom types, and missing values
#
class
VLLMDataType
(
enum
.
Enum
):
u4b8
=
enum_auto
()
u8b128
=
enum_auto
()
class
MixedInputKernelScheduleType
(
enum
.
Enum
):
TmaWarpSpecializedMixedInput
=
enum_auto
()
TmaWarpSpecializedPingpongMixedInput
=
enum_auto
()
TmaWarpSpecializedCooperativeMixedInput
=
enum_auto
()
VLLMDataTypeNames
:
Dict
[
Union
[
VLLMDataType
,
DataType
],
str
]
=
{
**
DataTypeNames
,
# type: ignore
**
{
VLLMDataType
.
u4b8
:
"u4b8"
,
VLLMDataType
.
u8b128
:
"u8b128"
,
}
}
VLLMDataTypeTag
:
Dict
[
Union
[
VLLMDataType
,
DataType
],
str
]
=
{
**
DataTypeTag
,
# type: ignore
**
{
VLLMDataType
.
u4b8
:
"cutlass::vllm_uint4b8_t"
,
VLLMDataType
.
u8b128
:
"cutlass::vllm_uint8b128_t"
,
}
}
VLLMKernelScheduleTag
:
Dict
[
Union
[
MixedInputKernelScheduleType
,
KernelScheduleType
],
str
]
=
{
**
KernelScheduleTag
,
# type: ignore
**
{
MixedInputKernelScheduleType
.
TmaWarpSpecializedMixedInput
:
"cutlass::gemm::KernelTmaWarpSpecializedMixedInput"
,
MixedInputKernelScheduleType
.
TmaWarpSpecializedPingpongMixedInput
:
"cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput"
,
MixedInputKernelScheduleType
.
TmaWarpSpecializedCooperativeMixedInput
:
"cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput"
,
}
}
csrc/cutlass_extensions/vllm_numeric_conversion.cuh
0 → 100644
View file @
ad385667
#pragma once
#include "cutlass/numeric_conversion.h"
#include "cutlass_extensions/vllm_custom_types.cuh"
#include "cutlass_extensions/cute_utils.cuh"
// this file extends:
// https://github.com/NVIDIA/cutlass/blob/cutlass-3.5.0/include/cutlass/numeric_conversion.h
// with vllm specific type conversions, namely: vllm_uint4b8_t, vllm_uint8b128_t
// as well as adds interleaved numeric array converters for specific types.
// (interleaved numeric array converters can be more efficient for subbyte
// types)
namespace
cutlass
{
// InterleavedNumericArrayConverter is like NumericArrayConverter but also
// deinterleaves converted elements based on IlvBlkLayout, interleaving can
// make subbyte converts more efficient by allowing for efficient extraction
// of subbyte elements from a 32bit register.
template
<
typename
IlvBlkLayout
,
typename
T
,
typename
S
,
int
N
,
FloatRoundStyle
Round
=
FloatRoundStyle
::
round_to_nearest
,
class
Enable
=
void
>
struct
InterleavedNumericArrayConverter
{
using
Converter
=
NumericArrayConverter
<
T
,
S
,
N
,
Round
>
;
using
result_type
=
typename
Converter
::
result_type
;
using
source_type
=
typename
Converter
::
source_type
;
CUTLASS_DEVICE
static
result_type
convert
(
source_type
const
&
source
)
{
CUTE_INVALID_CONTROL_PATH
(
"InterleavedNumericArrayConverter not implemented
\n
"
);
return
{};
}
CUTLASS_DEVICE
result_type
operator
()(
source_type
const
&
s
)
const
{
return
convert
(
s
);
}
};
template
<
typename
IlvBlkLayout
,
typename
T
,
typename
S
,
int
N
,
FloatRoundStyle
Round
>
struct
InterleavedNumericArrayConverter
<
IlvBlkLayout
,
T
,
S
,
N
,
Round
,
std
::
enable_if_t
<
is_identity_layout
<
IlvBlkLayout
>
()
>>
{
using
Converter
=
NumericArrayConverter
<
T
,
S
,
N
,
Round
>
;
using
result_type
=
typename
Converter
::
result_type
;
using
source_type
=
typename
Converter
::
source_type
;
CUTLASS_DEVICE
static
result_type
convert
(
source_type
const
&
source
)
{
return
Converter
::
convert
(
source
);
}
CUTLASS_DEVICE
result_type
operator
()(
source_type
const
&
s
)
const
{
return
convert
(
s
);
}
};
// TODO (LucasWilkinson): Implement
// for Array<cutlass::float8_e4m3fn, N> <= Array<vllm_uint4b8_t, N>
// ....
template
<
typename
RegConvert32bit
,
typename
T
,
typename
S
,
int
N
>
struct
ArrayConverterPacked32Bit
{
using
result_type
=
Array
<
T
,
N
>
;
using
source_type
=
Array
<
S
,
N
>
;
using
result_packed_8_t
=
Array
<
T
,
8
>
;
using
result_packed_4_t
=
Array
<
T
,
4
>
;
using
result_packed_2_t
=
Array
<
T
,
2
>
;
using
src_packed_8_t
=
Array
<
S
,
8
>
;
using
src_packed_4_t
=
Array
<
S
,
4
>
;
using
src_packed_2_t
=
Array
<
S
,
2
>
;
static_assert
(
N
%
2
==
0
,
"N must be a multiple of 2"
);
static_assert
(
cutlass
::
sizeof_bits_v
<
S
>
>=
4
);
// TODO: add 16 packed sources
static_assert
(
32
%
cutlass
::
sizeof_bits_v
<
S
>
==
0
);
static
constexpr
auto
src_elems_per_32bit_reg
=
32
/
cutlass
::
sizeof_bits_v
<
S
>
;
// Maybe not Valid. ScalarConverter will not actually work unless
// NumericConverter<T, S, Round> is implemented. However it won't be used
// anyways since we assert N % 2 == 0, just here for compliance with
// VectorizedConverter.
using
ScalarConverter
=
NumericConverter
<
T
,
S
>
;
template
<
typename
PackedSrc
>
CUTLASS_DEVICE
static
uint32_t
to_reg
(
PackedSrc
const
&
source
)
{
if
constexpr
(
sizeof
(
PackedSrc
)
==
1
)
{
return
static_cast
<
uint32_t
>
(
reinterpret_cast
<
const
uint8_t
&>
(
source
));
}
else
if
constexpr
(
sizeof
(
PackedSrc
)
==
2
)
{
return
static_cast
<
uint32_t
>
(
reinterpret_cast
<
const
uint16_t
&>
(
source
));
}
else
{
static_assert
(
sizeof
(
PackedSrc
)
==
4
);
return
reinterpret_cast
<
const
uint32_t
&>
(
source
);
}
}
// The core converter uses bit tricks to construct a known FP16 number, then
// does a subtraction in FP16 for the final result.
template
<
typename
PackedResultType
,
typename
PackedSrcType
>
CUTLASS_DEVICE
static
PackedResultType
packed_convert
(
PackedSrcType
const
&
source
)
{
static_assert
(
PackedSrcType
::
kElements
==
PackedResultType
::
kElements
);
static_assert
(
PackedResultType
::
kElements
==
2
||
PackedResultType
::
kElements
==
4
||
PackedResultType
::
kElements
==
8
,
"Invalid PackedResultType must be 2, 4 or 8."
);
static_assert
(
std
::
is_same_v
<
typename
PackedSrcType
::
Element
,
S
>
);
static_assert
(
std
::
is_same_v
<
typename
PackedResultType
::
Element
,
T
>
);
return
RegConvert32bit
::
template
convert
<
PackedResultType
>(
to_reg
(
source
));
}
friend
class
detail
::
VectorizedConverter
;
public:
CUTLASS_DEVICE
static
result_type
convert
(
source_type
const
&
source
)
{
result_type
result
;
using
ConverterType
=
ArrayConverterPacked32Bit
<
RegConvert32bit
,
typename
result_type
::
Element
,
typename
source_type
::
Element
,
N
>
;
if
constexpr
(
src_elems_per_32bit_reg
>=
8
)
{
detail
::
VectorizedConverter
::
convert
<
ConverterType
,
result_packed_8_t
,
src_packed_8_t
,
result_packed_4_t
,
src_packed_4_t
,
result_packed_2_t
,
src_packed_2_t
>
(
result
,
source
);
}
else
if
constexpr
(
src_elems_per_32bit_reg
>=
4
)
{
detail
::
VectorizedConverter
::
convert
<
ConverterType
,
result_packed_4_t
,
src_packed_4_t
,
result_packed_2_t
,
src_packed_2_t
>
(
result
,
source
);
}
else
{
detail
::
VectorizedConverter
::
convert
<
ConverterType
,
result_packed_2_t
,
src_packed_2_t
>
(
result
,
source
);
}
return
result
;
}
};
// for Array<cutlass::half_t, N> <= Array<vllm_uint4b8_t, N>
template
<
FloatRoundStyle
Round
,
int
N
>
struct
NumericArrayConverter
<
cutlass
::
half_t
,
vllm_uint4b8_t
,
N
,
Round
>
{
using
result_type
=
Array
<
cutlass
::
half_t
,
N
>
;
using
source_type
=
Array
<
vllm_uint4b8_t
,
N
>
;
struct
RegConvert
{
template
<
typename
PackedResultType
>
CUTLASS_DEVICE
static
PackedResultType
convert
(
uint32_t
src
)
{
using
RegArray
=
cutlass
::
AlignedArray
<
uint32_t
,
PackedResultType
::
kElements
/
2
,
sizeof
(
PackedResultType
)
>
;
RegArray
r
;
// Below constructs the following temporary:
// fp16s_01 = {0x00, i4_01, 0x00, i4_01}
// fp16s_23 = {0x00, i4_23, 0x00, i4_23}
// fp16s_45 = {0x00, i4_45, 0x00, i4_45}
// fp16s_67 = {0x00, i4_67, 0x00, i4_67}
// We use inline asm instead of __byte_perm intrinsic since we don't want
// the documented (& 0x7) on the index. NVCC might be able to optimize it
// out since the index is a constexpr, but we choose to be safe about it
// here.
uint32_t
prmt_indices
[
4
]
=
{
0x4040
,
0x4141
,
0x4242
,
0x4343
};
static_assert
(
RegArray
::
kElements
<=
4
,
"Too many inputs for F16 -> I4 vector converter"
);
CUTLASS_PRAGMA_UNROLL
for
(
int
ii
=
0
;
ii
<
RegArray
::
kElements
;
++
ii
)
{
asm
volatile
(
"{
\n
"
" prmt.b32 %0, %1, %2, %3;
\n
"
"}
\n
"
:
"=r"
(
r
[
ii
])
:
"r"
(
src
),
"n"
(
0
),
"r"
(
prmt_indices
[
ii
]));
}
// Since the stored 4bit values are biased by 8 we get stored_val = (x+8)
// we are trying to construct x and a fp16 value
// The below XOR does the following:
// 1) Sets the exponent bits of the FP16 to the correct value for the
// FP16 magic_num. We will be constructing {1024+16*(x1+8), 1024+(x0+8)},
// where x1 in the high nibble and x0 is the low nibble then using hfma
// to subtract 1032 from that
// The AND does the following:
// 1) Clear the set bits for the int4 we will ignore.
// We use lop3 so that we can use 1 instruction for AND and XOR.
static
constexpr
uint32_t
xor_mask
=
0x64006400
;
static
constexpr
uint32_t
and_mask
=
0xFFF0FF0F
;
static
constexpr
uint32_t
immLut
=
(
0xf0
&
0xcc
)
^
0xaa
;
// For each operand, computes:
// r[i] = (r[i] & and_mask) ^ xor_mask
CUTLASS_PRAGMA_UNROLL
for
(
int
ii
=
0
;
ii
<
RegArray
::
kElements
;
++
ii
)
{
asm
volatile
(
"{
\n
"
" lop3.b32 %0, %0, %1, %2, %3;
\n
"
"}
\n
"
:
"+r"
(
r
[
ii
])
:
"n"
(
and_mask
),
"n"
(
xor_mask
),
"n"
(
immLut
));
}
// We will issue 2 hfmas that do the following:
// {x1, x0} = {1024+16*(x1+8), 1024+(x0+8)} * {1/16, 1} - {72, 1032}
// = {x1 + 1152, x0 + 1032} * {1/16, 1} - {72, 1032}
static
constexpr
uint32_t
hfma_bias_rep
=
0xD480E408
;
// {72, 1032}
static
constexpr
uint32_t
hfma_scale_rep
=
0x2C003C00
;
// {1 / 16, 1}
const
half2
&
hfma_bias
=
reinterpret_cast
<
const
half2
&>
(
hfma_bias_rep
);
const
half2
&
hfma_scale
=
reinterpret_cast
<
const
half2
&>
(
hfma_scale_rep
);
CUTLASS_PRAGMA_UNROLL
for
(
int
ii
=
0
;
ii
<
RegArray
::
kElements
;
++
ii
)
{
half2
&
fp16x2_val
=
reinterpret_cast
<
__half2
&>
(
r
[
ii
]);
fp16x2_val
=
__hfma2
(
hfma_scale
,
fp16x2_val
,
hfma_bias
);
}
return
reinterpret_cast
<
PackedResultType
&>
(
r
);
};
};
public:
CUTLASS_DEVICE
static
result_type
convert
(
source_type
const
&
source
)
{
return
ArrayConverterPacked32Bit
<
RegConvert
,
typename
result_type
::
Element
,
typename
source_type
::
Element
,
N
>::
convert
(
source
);
}
CUTLASS_DEVICE
result_type
operator
()(
source_type
const
&
s
)
const
{
return
convert
(
s
);
}
};
// for Array<cutlass::half_t, N> <= Array<vllm_uint4b8_t, N>
// for IlvdLayout: (2, 4):(4, 1)
template
<
FloatRoundStyle
Round
,
int
N
>
struct
InterleavedNumericArrayConverter
<
Layout
<
Shape
<
_2
,
_4
>
,
Stride
<
_4
,
_1
>>
,
cutlass
::
half_t
,
vllm_uint4b8_t
,
N
,
Round
,
void
>
{
using
IlvdLayout
=
Layout
<
Shape
<
_2
,
_4
>
,
Stride
<
_4
,
_1
>>
;
static_assert
(
N
%
size
(
IlvdLayout
{})
==
0
);
using
result_type
=
Array
<
cutlass
::
half_t
,
N
>
;
using
source_type
=
Array
<
vllm_uint4b8_t
,
N
>
;
static
FloatRoundStyle
const
round_style
=
Round
;
private:
struct
RegConvert
{
template
<
typename
PackedResultType
>
CUTLASS_DEVICE
static
PackedResultType
convert
(
uint32_t
src
)
{
using
RegArray
=
cutlass
::
AlignedArray
<
uint32_t
,
PackedResultType
::
kElements
/
2
,
sizeof
(
PackedResultType
)
>
;
RegArray
r
;
static_assert
(
PackedResultType
::
kElements
<=
size
(
IlvdLayout
{}));
static
constexpr
uint32_t
xor_mask
=
0x64006400
;
for
(
int
ii
=
0
;
ii
<
RegArray
::
kElements
;
ii
+=
2
)
{
auto
src_
=
src
>>
(
4
*
(
ii
));
r
[
ii
+
0
]
=
src_
;
r
[
ii
+
1
]
=
src_
;
static
constexpr
uint32_t
and_xor_imm_lut
=
(
0xf0
&
0xcc
)
^
0xaa
;
static
constexpr
uint32_t
low_nib_mask
=
0x000F000F
;
static
constexpr
uint32_t
high_nib_mask
=
0x00F000F0
;
asm
volatile
(
"{
\n
"
" lop3.b32 %0, %0, %1, %2, %3;
\n
"
"}
\n
"
:
"+r"
(
r
[
ii
+
0
])
:
"n"
(
low_nib_mask
),
"n"
(
xor_mask
),
"n"
(
and_xor_imm_lut
));
asm
volatile
(
"{
\n
"
" lop3.b32 %0, %0, %1, %2, %3;
\n
"
"}
\n
"
:
"+r"
(
r
[
ii
+
1
])
:
"n"
(
high_nib_mask
),
"n"
(
xor_mask
),
"n"
(
and_xor_imm_lut
));
// For low nibble:
// {x1, x0} = {1024+(x1+8), 1024+(x0+8)} * {1, 1} - {1032, 1032}
// For high nibble:
// {x1, x0} = {1024+16*(x1+8), 1024+16*(x0+8)} * {1/16, 1/16}
// - {72, 72}
static
constexpr
uint32_t
low_nib_bias
=
0x64086408
;
// {1032, 1032}
static
constexpr
uint32_t
high_nib_scale
=
0x2C002C00
;
// {1/16, 1/16}
static
constexpr
uint32_t
high_nib_bias
=
0xD480D480
;
// {-72, -72}
{
half2
&
fp16x2_val
=
reinterpret_cast
<
__half2
&>
(
r
[
ii
+
0
]);
fp16x2_val
=
__hsub2
(
fp16x2_val
,
reinterpret_cast
<
const
half2
&>
(
low_nib_bias
));
}
{
half2
&
fp16x2_val
=
reinterpret_cast
<
__half2
&>
(
r
[
ii
+
1
]);
fp16x2_val
=
__hfma2
(
fp16x2_val
,
reinterpret_cast
<
const
half2
&>
(
high_nib_scale
),
reinterpret_cast
<
const
half2
&>
(
high_nib_bias
));
}
}
return
reinterpret_cast
<
PackedResultType
&>
(
r
);
};
};
public:
CUTLASS_DEVICE
static
result_type
convert
(
source_type
const
&
source
)
{
return
ArrayConverterPacked32Bit
<
RegConvert
,
typename
result_type
::
Element
,
typename
source_type
::
Element
,
N
>::
convert
(
source
);
}
CUTLASS_DEVICE
result_type
operator
()(
source_type
const
&
s
)
const
{
return
convert
(
s
);
}
};
// for Array<cutlass::half_t, N> <= Array<uint4_t, N>
// for IlvdLayout: (2, 4):(4, 1)
template
<
FloatRoundStyle
Round
,
int
N
>
struct
InterleavedNumericArrayConverter
<
Layout
<
Shape
<
_2
,
_4
>
,
Stride
<
_4
,
_1
>>
,
cutlass
::
half_t
,
uint4_t
,
N
,
Round
,
void
>
{
using
IlvdLayout
=
Layout
<
Shape
<
_2
,
_4
>
,
Stride
<
_4
,
_1
>>
;
static_assert
(
N
%
size
(
IlvdLayout
{})
==
0
);
using
result_type
=
Array
<
cutlass
::
half_t
,
N
>
;
using
source_type
=
Array
<
uint4_t
,
N
>
;
static
FloatRoundStyle
const
round_style
=
Round
;
private:
struct
RegConvert
{
template
<
typename
PackedResultType
>
CUTLASS_DEVICE
static
PackedResultType
convert
(
uint32_t
src
)
{
using
RegArray
=
cutlass
::
AlignedArray
<
uint32_t
,
PackedResultType
::
kElements
/
2
,
sizeof
(
PackedResultType
)
>
;
RegArray
r
;
static_assert
(
PackedResultType
::
kElements
<=
size
(
IlvdLayout
{}));
static
constexpr
uint32_t
xor_mask
=
0x64006400
;
for
(
int
ii
=
0
;
ii
<
RegArray
::
kElements
;
ii
+=
2
)
{
auto
src_
=
src
>>
(
4
*
(
ii
));
r
[
ii
+
0
]
=
src_
;
r
[
ii
+
1
]
=
src_
;
static
constexpr
uint32_t
and_xor_imm_lut
=
(
0xf0
&
0xcc
)
^
0xaa
;
static
constexpr
uint32_t
low_nib_mask
=
0x000F000F
;
static
constexpr
uint32_t
high_nib_mask
=
0x00F000F0
;
asm
volatile
(
"{
\n
"
" lop3.b32 %0, %0, %1, %2, %3;
\n
"
"}
\n
"
:
"+r"
(
r
[
ii
+
0
])
:
"n"
(
low_nib_mask
),
"n"
(
xor_mask
),
"n"
(
and_xor_imm_lut
));
asm
volatile
(
"{
\n
"
" lop3.b32 %0, %0, %1, %2, %3;
\n
"
"}
\n
"
:
"+r"
(
r
[
ii
+
1
])
:
"n"
(
high_nib_mask
),
"n"
(
xor_mask
),
"n"
(
and_xor_imm_lut
));
// For low nibble:
// {x1, x0} = {1024+x1, 1024+x0} - {1024, 1024}
// For high nibble:
// {x1, x0} = {1024+16*x1, 1024+16*x0} * {1/16, 1/16} - {64, 64}
static
constexpr
uint32_t
low_nib_bias
=
0x64006400
;
// {1024, 1024}
static
constexpr
uint32_t
high_nib_scale
=
0x2C002C00
;
// {1/16, 1/16}
static
constexpr
uint32_t
high_nib_bias
=
0xD400D400
;
// {-64, -64}
{
half2
&
fp16x2_val
=
reinterpret_cast
<
__half2
&>
(
r
[
ii
+
0
]);
fp16x2_val
=
__hsub2
(
fp16x2_val
,
reinterpret_cast
<
const
half2
&>
(
low_nib_bias
));
}
{
half2
&
fp16x2_val
=
reinterpret_cast
<
__half2
&>
(
r
[
ii
+
1
]);
fp16x2_val
=
__hfma2
(
fp16x2_val
,
reinterpret_cast
<
const
half2
&>
(
high_nib_scale
),
reinterpret_cast
<
const
half2
&>
(
high_nib_bias
));
}
}
return
reinterpret_cast
<
PackedResultType
&>
(
r
);
};
};
public:
CUTLASS_DEVICE
static
result_type
convert
(
source_type
const
&
source
)
{
return
ArrayConverterPacked32Bit
<
RegConvert
,
typename
result_type
::
Element
,
typename
source_type
::
Element
,
N
>::
convert
(
source
);
}
CUTLASS_DEVICE
result_type
operator
()(
source_type
const
&
s
)
const
{
return
convert
(
s
);
}
};
// for Array<cutlass::half_t, N> <= Array<vllm_uint8b128_t, N>
template
<
FloatRoundStyle
Round
,
int
N
>
struct
NumericArrayConverter
<
cutlass
::
half_t
,
vllm_uint8b128_t
,
N
,
Round
>
{
using
result_type
=
Array
<
cutlass
::
half_t
,
N
>
;
using
source_type
=
Array
<
vllm_uint8b128_t
,
N
>
;
struct
RegConvert
{
template
<
typename
PackedResultType
>
CUTLASS_DEVICE
static
PackedResultType
convert
(
uint32_t
src
)
{
// Hold output FP16s in reg. We need 1 reg for every 2 elements
using
RegArray
=
cutlass
::
AlignedArray
<
uint32_t
,
PackedResultType
::
kElements
/
2
,
sizeof
(
PackedResultType
)
>
;
RegArray
r
;
uint32_t
const
prmt_indices
[
2
]
=
{
0x5150
,
0x5352
};
static
constexpr
uint32_t
start_byte_for_fp16
=
0x64646464
;
for
(
int
ii
=
0
;
ii
<
RegArray
::
kElements
;
++
ii
)
{
asm
volatile
(
"prmt.b32 %0,%1,%2,%3;
\n
"
:
"=r"
(
r
[
ii
])
:
"r"
(
src
),
"n"
(
start_byte_for_fp16
),
"r"
(
prmt_indices
[
ii
]));
}
// -128 is folded into bias subtraction, i.e. the 0x80 in the low bytes
static
constexpr
uint32_t
bias_rep
=
0x64806480
;
const
half2
&
bias
=
reinterpret_cast
<
const
half2
&>
(
bias_rep
);
CUTLASS_PRAGMA_UNROLL
for
(
int
ii
=
0
;
ii
<
RegArray
::
kElements
;
++
ii
)
{
half2
&
fp16x2_val
=
reinterpret_cast
<
__half2
&>
(
r
[
ii
]);
fp16x2_val
=
__hsub2
(
fp16x2_val
,
bias
);
}
return
reinterpret_cast
<
PackedResultType
&>
(
r
);
};
};
public:
CUTLASS_DEVICE
static
result_type
convert
(
source_type
const
&
source
)
{
return
ArrayConverterPacked32Bit
<
RegConvert
,
typename
result_type
::
Element
,
typename
source_type
::
Element
,
N
>::
convert
(
source
);
}
CUTLASS_DEVICE
result_type
operator
()(
source_type
const
&
s
)
const
{
return
convert
(
s
);
}
};
// for Array<cutlass::float, N> <= Array<vllm_uint8b128_t, N>
template
<
FloatRoundStyle
Round
,
int
N
>
struct
NumericArrayConverter
<
float
,
vllm_uint8b128_t
,
N
,
Round
>
{
using
result_type
=
Array
<
float
,
N
>
;
using
source_type
=
Array
<
vllm_uint8b128_t
,
N
>
;
static
FloatRoundStyle
const
round_style
=
Round
;
private:
struct
RegConvert
{
template
<
typename
PackedResultType
>
CUTLASS_DEVICE
static
PackedResultType
convert
(
uint32_t
src
)
{
PackedResultType
r
;
// __byte_perm simulates the add.u32 0x4B000000 to every u8 element of
// u8x4 source and stores the result in r (without introducing extra
// cvt.u32.u8 instruction)
uint32_t
const
prmt_indices
[
4
]
=
{
0x7650
,
0x7651
,
0x7652
,
0x7653
};
uint32_t
*
result_as_int
=
reinterpret_cast
<
uint32_t
*>
(
&
r
);
for
(
int
ii
=
0
;
ii
<
PackedResultType
::
kElements
;
++
ii
)
{
result_as_int
[
ii
]
=
__byte_perm
(
src
,
0x4B000000
,
prmt_indices
[
ii
]);
// Subtract the magic number 0x4B000000 from tmp in floating-point
// arithmetic to obtain final result
r
[
ii
]
-=
(
8388608.
f
+
128.
f
);
// fold in -128 bias
}
return
r
;
};
};
public:
CUTLASS_DEVICE
static
result_type
convert
(
source_type
const
&
source
)
{
return
ArrayConverterPacked32Bit
<
RegConvert
,
typename
result_type
::
Element
,
typename
source_type
::
Element
,
N
>::
convert
(
source
);
}
CUTLASS_DEVICE
result_type
operator
()(
source_type
const
&
s
)
const
{
return
convert
(
s
);
}
};
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
// for Array<cutlass::bfloat16_t, N> <= Array<vllm_uint4b8_t, N>
template
<
FloatRoundStyle
Round
,
int
N
>
struct
NumericArrayConverter
<
cutlass
::
bfloat16_t
,
vllm_uint4b8_t
,
N
,
Round
>
{
using
result_type
=
Array
<
cutlass
::
bfloat16_t
,
N
>
;
using
source_type
=
Array
<
vllm_uint4b8_t
,
N
>
;
static
FloatRoundStyle
const
round_style
=
Round
;
private:
struct
RegConvert
{
template
<
typename
PackedResultType
>
CUTLASS_DEVICE
static
PackedResultType
convert
(
uint32_t
src_reg
)
{
// Hold output BF16s in reg. We need 1 reg for every 2 elements
using
RegArray
=
cutlass
::
AlignedArray
<
uint32_t
,
PackedResultType
::
kElements
/
2
,
sizeof
(
PackedResultType
)
>
;
RegArray
r
;
uint32_t
src_reg_shifted
=
src_reg
>>
4
;
// Below constructs the following temporary:
uint32_t
const
prmt_indices
[
4
]
=
{
0xF4F0
,
0xF5F1
,
0xF6F2
,
0xF7F3
};
static_assert
(
RegArray
::
kElements
<=
4
,
"Too many inputs for uint4b8_t -> BF16 vector converter"
);
CUTLASS_PRAGMA_UNROLL
for
(
int
ii
=
0
;
ii
<
RegArray
::
kElements
;
++
ii
)
{
asm
volatile
(
"{
\n
"
" prmt.b32 %0, %1, %2, %3;
\n
"
"}
\n
"
:
"=r"
(
r
[
ii
])
:
"r"
(
src_reg
),
"r"
(
src_reg_shifted
),
"r"
(
prmt_indices
[
ii
]));
}
// Since the stored 4bit values are biased by 8 we get stored_val = (x+8)
// we are trying to construct x and a BF16 value
// The below XOR does the following:
// 1) Sets the exponent bits of the BF16 to the correct value for the
// BF16 magic_num. We will be constructing {128 + (x1+8), 128 + (x0+8)}
// and subtracting 136 to get {x1, x0}
static
constexpr
uint32_t
xor_mask
=
0x43004300
;
static
constexpr
uint32_t
and_mask
=
0x000F000F
;
static
constexpr
uint32_t
immLut
=
(
0xf0
&
0xcc
)
^
0xaa
;
// For each operand, computes:
// r[i] = (r[i] & and_mask) ^ xor_mask
CUTLASS_PRAGMA_UNROLL
for
(
int
ii
=
0
;
ii
<
RegArray
::
kElements
;
++
ii
)
{
asm
volatile
(
"{
\n
"
" lop3.b32 %0, %0, %1, %2, %3;
\n
"
"}
\n
"
:
"+r"
(
r
[
ii
])
:
"n"
(
and_mask
),
"n"
(
xor_mask
),
"n"
(
immLut
));
}
// We will issue 2 bfmas that do the following:
// high BF16:
// hi_bf16 - 136, lo_bf16 - 136
// This is the BF16 {136, 136} represented as an integer.
static
constexpr
uint32_t
bias_rep
=
0x43084308
;
const
__nv_bfloat162
&
bias
=
reinterpret_cast
<
const
__nv_bfloat162
&>
(
bias_rep
);
CUTLASS_PRAGMA_UNROLL
for
(
int
ii
=
0
;
ii
<
RegArray
::
kElements
;
++
ii
)
{
__nv_bfloat162
&
bf16x2_val
=
reinterpret_cast
<
__nv_bfloat162
&>
(
r
[
ii
]);
bf16x2_val
=
__hsub2
(
bf16x2_val
,
bias
);
}
return
reinterpret_cast
<
PackedResultType
&>
(
r
);
}
};
public:
CUTLASS_DEVICE
static
result_type
convert
(
source_type
const
&
source
)
{
return
ArrayConverterPacked32Bit
<
RegConvert
,
typename
result_type
::
Element
,
typename
source_type
::
Element
,
N
>::
convert
(
source
);
}
CUTLASS_DEVICE
result_type
operator
()(
source_type
const
&
s
)
const
{
return
convert
(
s
);
}
};
// for Array<cutlass::bfloat16_t, N> <= Array<vllm_uint4b8_t, N>
// for IlvdLayout: (2, 4):(4, 1)
template
<
FloatRoundStyle
Round
,
int
N
>
struct
InterleavedNumericArrayConverter
<
Layout
<
Shape
<
_2
,
_4
>
,
Stride
<
_4
,
_1
>>
,
cutlass
::
bfloat16_t
,
vllm_uint4b8_t
,
N
,
Round
,
void
>
{
using
IlvdLayout
=
Layout
<
Shape
<
_2
,
_4
>
,
Stride
<
_4
,
_1
>>
;
static_assert
(
N
%
size
(
IlvdLayout
{})
==
0
);
using
result_type
=
Array
<
cutlass
::
bfloat16_t
,
N
>
;
using
source_type
=
Array
<
vllm_uint4b8_t
,
N
>
;
private:
struct
RegConvert
{
template
<
typename
PackedResultType
>
CUTLASS_DEVICE
static
PackedResultType
convert
(
uint32_t
src
)
{
using
RegArray
=
cutlass
::
AlignedArray
<
uint32_t
,
PackedResultType
::
kElements
/
2
,
sizeof
(
PackedResultType
)
>
;
RegArray
r
;
static_assert
(
PackedResultType
::
kElements
<=
size
(
IlvdLayout
{}));
static
constexpr
uint32_t
or_mask
=
0x43004300
;
// Unlike float16 where the mantissa is large enough to contain 2
// nibbles, bfloat16 can only fit one, so we can only convert one
// nibble at a time
for
(
int
ii
=
0
;
ii
<
RegArray
::
kElements
;
++
ii
)
{
r
[
ii
]
=
src
>>
(
4
*
ii
);
static
constexpr
uint32_t
and_or_imm_lut
=
(
0xf0
&
0xcc
)
|
0xaa
;
static
constexpr
uint32_t
low_nib_mask
=
0x000F000F
;
asm
volatile
(
"{
\n
"
" lop3.b32 %0, %0, %1, %2, %3;
\n
"
"}
\n
"
:
"+r"
(
r
[
ii
+
0
])
:
"n"
(
low_nib_mask
),
"n"
(
or_mask
),
"n"
(
and_or_imm_lut
));
// For low nibble:
// {x1, x0} = {128+(x1+8), 128+(x0+8)} * {1, 1} - {136, 136}
static
constexpr
uint32_t
low_nib_bias
=
0x43084308
;
// {136, 136}
{
__nv_bfloat162
&
fp16x2_val
=
reinterpret_cast
<
__nv_bfloat162
&>
(
r
[
ii
]);
fp16x2_val
=
__hsub2
(
fp16x2_val
,
reinterpret_cast
<
const
__nv_bfloat162
&>
(
low_nib_bias
));
}
}
return
reinterpret_cast
<
PackedResultType
&>
(
r
);
};
};
public:
CUTLASS_DEVICE
static
result_type
convert
(
source_type
const
&
source
)
{
return
ArrayConverterPacked32Bit
<
RegConvert
,
typename
result_type
::
Element
,
typename
source_type
::
Element
,
N
>::
convert
(
source
);
}
CUTLASS_DEVICE
result_type
operator
()(
source_type
const
&
s
)
const
{
return
convert
(
s
);
}
};
// for Array<cutlass::bfloat16_t, N> <= Array<uint4_t, N>
// for IlvdLayout: (2, 4):(4, 1)
template
<
FloatRoundStyle
Round
,
int
N
>
struct
InterleavedNumericArrayConverter
<
Layout
<
Shape
<
_2
,
_4
>
,
Stride
<
_4
,
_1
>>
,
cutlass
::
bfloat16_t
,
uint4_t
,
N
,
Round
,
void
>
{
using
IlvdLayout
=
Layout
<
Shape
<
_2
,
_4
>
,
Stride
<
_4
,
_1
>>
;
static_assert
(
N
%
size
(
IlvdLayout
{})
==
0
);
using
result_type
=
Array
<
cutlass
::
bfloat16_t
,
N
>
;
using
source_type
=
Array
<
uint4_t
,
N
>
;
private:
struct
RegConvert
{
template
<
typename
PackedResultType
>
CUTLASS_DEVICE
static
PackedResultType
convert
(
uint32_t
src
)
{
using
RegArray
=
cutlass
::
AlignedArray
<
uint32_t
,
PackedResultType
::
kElements
/
2
,
sizeof
(
PackedResultType
)
>
;
RegArray
r
;
static_assert
(
PackedResultType
::
kElements
<=
size
(
IlvdLayout
{}));
static
constexpr
uint32_t
or_mask
=
0x43004300
;
// Unlike float16 where the mantissa is large enough to contain 2
// nibbles, bfloat16 can only fit one, so we can only convert one
// nibble at a time
for
(
int
ii
=
0
;
ii
<
RegArray
::
kElements
;
++
ii
)
{
r
[
ii
]
=
src
>>
(
4
*
ii
);
static
constexpr
uint32_t
and_or_imm_lut
=
(
0xf0
&
0xcc
)
|
0xaa
;
static
constexpr
uint32_t
low_nib_mask
=
0x000F000F
;
asm
volatile
(
"{
\n
"
" lop3.b32 %0, %0, %1, %2, %3;
\n
"
"}
\n
"
:
"+r"
(
r
[
ii
])
:
"n"
(
low_nib_mask
),
"n"
(
or_mask
),
"n"
(
and_or_imm_lut
));
// For low nibble:
// {x1, x0} = {128 + x1, 128 + x0} * {1, 1} - {128, 128}
static
constexpr
uint32_t
low_nib_bias
=
0x43004300
;
// {128, 128}
{
__nv_bfloat162
&
fp16x2_val
=
reinterpret_cast
<
__nv_bfloat162
&>
(
r
[
ii
]);
fp16x2_val
=
__hsub2
(
fp16x2_val
,
reinterpret_cast
<
const
__nv_bfloat162
&>
(
low_nib_bias
));
}
}
return
reinterpret_cast
<
PackedResultType
&>
(
r
);
};
};
public:
CUTLASS_DEVICE
static
result_type
convert
(
source_type
const
&
source
)
{
return
ArrayConverterPacked32Bit
<
RegConvert
,
typename
result_type
::
Element
,
typename
source_type
::
Element
,
N
>::
convert
(
source
);
}
CUTLASS_DEVICE
result_type
operator
()(
source_type
const
&
s
)
const
{
return
convert
(
s
);
}
};
// for Array<cutlass::bfloat16_t, N> <= Array<vllm_uint8b128_t, N>
template
<
FloatRoundStyle
Round
,
int
N
>
struct
NumericArrayConverter
<
cutlass
::
bfloat16_t
,
vllm_uint8b128_t
,
N
,
Round
>
{
using
result_type
=
Array
<
cutlass
::
bfloat16_t
,
N
>
;
using
source_type
=
Array
<
vllm_uint8b128_t
,
N
>
;
static
FloatRoundStyle
const
round_style
=
Round
;
private:
using
result_packed_4_t
=
Array
<
cutlass
::
bfloat16_t
,
4
>
;
using
result_packed_2_t
=
Array
<
cutlass
::
bfloat16_t
,
2
>
;
using
src_packed_4_t
=
Array
<
vllm_uint8b128_t
,
4
>
;
using
src_packed_2_t
=
Array
<
vllm_uint8b128_t
,
2
>
;
// Not Valid, not supported, only here to satisfy the interface and to avoid
// a compile error. ScalarConverter will not actually work until
// NumericConverter<cutlass::bfloat16_t, vllm_uint8b128_t, Round> is
// implemented
using
ScalarConverter
=
NumericConverter
<
cutlass
::
bfloat16_t
,
vllm_uint8b128_t
,
Round
>
;
template
<
typename
PackedResultType
,
typename
PackedSrcType
>
CUTLASS_DEVICE
static
PackedResultType
packed_convert
(
PackedSrcType
const
&
source
)
{
static_assert
(
(
platform
::
is_same
<
PackedSrcType
,
src_packed_2_t
>::
value
&&
platform
::
is_same
<
PackedResultType
,
result_packed_2_t
>::
value
)
||
(
platform
::
is_same
<
PackedSrcType
,
src_packed_4_t
>::
value
&&
platform
::
is_same
<
PackedResultType
,
result_packed_4_t
>::
value
),
"Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private "
"convert dispatch."
);
NumericArrayConverter
<
float
,
vllm_uint8b128_t
,
PackedResultType
::
kElements
,
Round
>
convert_uint8_to_f32
;
Array
<
float
,
PackedResultType
::
kElements
>
tmp
=
convert_uint8_to_f32
(
source
);
NumericArrayConverter
<
cutlass
::
bfloat16_t
,
float
,
PackedResultType
::
kElements
,
Round
>
convert_f32_to_bf16_
;
return
convert_f32_to_bf16_
(
tmp
);
}
friend
class
detail
::
VectorizedConverter
;
public:
CUTLASS_DEVICE
static
result_type
convert
(
source_type
const
&
source
)
{
result_type
result
;
using
ConverterType
=
NumericArrayConverter
<
typename
result_type
::
Element
,
typename
source_type
::
Element
,
N
,
Round
>
;
detail
::
VectorizedConverter
::
convert
<
ConverterType
,
result_packed_4_t
,
src_packed_4_t
,
result_packed_2_t
,
src_packed_2_t
>
(
result
,
source
);
return
result
;
}
CUTLASS_DEVICE
result_type
operator
()(
source_type
const
&
s
)
const
{
return
convert
(
s
);
}
};
#endif
/////////////////////////////////////////////////////////////////////////////////////////////////
}
// namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
csrc/layernorm_kernels.cu
View file @
ad385667
...
@@ -3,13 +3,16 @@
...
@@ -3,13 +3,16 @@
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAGuard.h>
#include "dispatch_utils.h"
#include "dispatch_utils.h"
#include "reduction_utils.cuh"
#ifndef USE_ROCM
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp16.h>
#include <cub/util_type.cuh>
#include <cub/cub.cuh>
#else
#else
#include <hip/hip_bf16.h>
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_fp16.h>
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp>
using
__nv_bfloat16
=
__hip_bfloat16
;
using
__nv_bfloat16
=
__hip_bfloat16
;
using
__nv_bfloat162
=
__hip_bfloat162
;
using
__nv_bfloat162
=
__hip_bfloat162
;
...
@@ -31,7 +34,11 @@ __global__ void rms_norm_kernel(
...
@@ -31,7 +34,11 @@ __global__ void rms_norm_kernel(
const
float
x
=
(
float
)
input
[
blockIdx
.
x
*
hidden_size
+
idx
];
const
float
x
=
(
float
)
input
[
blockIdx
.
x
*
hidden_size
+
idx
];
variance
+=
x
*
x
;
variance
+=
x
*
x
;
}
}
variance
=
blockReduceSum
<
float
>
(
variance
);
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
variance
=
BlockReduce
(
reduceStore
).
Reduce
(
variance
,
cub
::
Sum
{},
blockDim
.
x
);
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
}
}
...
@@ -228,12 +235,11 @@ fused_add_rms_norm_kernel(
...
@@ -228,12 +235,11 @@ fused_add_rms_norm_kernel(
variance
+=
temp
.
sum_squares
();
variance
+=
temp
.
sum_squares
();
residual_v
[
id
]
=
temp
;
residual_v
[
id
]
=
temp
;
}
}
/* Keep the following if-else block in sync with the
calculation of max_block_size in fused_add_rms_norm */
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
if
(
num_tokens
<
256
)
{
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
variance
=
blockReduceSum
<
float
,
1024
>
(
variance
);
variance
=
BlockReduce
(
reduceStore
).
Reduce
(
variance
,
cub
::
Sum
{},
blockDim
.
x
);
}
else
variance
=
blockReduceSum
<
float
,
256
>
(
variance
);
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
}
}
...
@@ -268,12 +274,11 @@ fused_add_rms_norm_kernel(
...
@@ -268,12 +274,11 @@ fused_add_rms_norm_kernel(
variance
+=
x
*
x
;
variance
+=
x
*
x
;
residual
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
z
;
residual
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
z
;
}
}
/* Keep the following if-else block in sync with the
calculation of max_block_size in fused_add_rms_norm */
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
if
(
num_tokens
<
256
)
{
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
variance
=
blockReduceSum
<
float
,
1024
>
(
variance
);
variance
=
BlockReduce
(
reduceStore
).
Reduce
(
variance
,
cub
::
Sum
{},
blockDim
.
x
);
}
else
variance
=
blockReduceSum
<
float
,
256
>
(
variance
);
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
}
}
...
...
csrc/mamba/causal_conv1d/causal_conv1d.cu
0 → 100644
View file @
ad385667
// clang-format off
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_fwd.cu
// and https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_update.cu
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "causal_conv1d.h"
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
#include <cub/block/block_load.cuh>
#include <cub/block/block_store.cuh>
#include "static_switch.h"
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
if (ITYPE == at::ScalarType::Half) { \
using input_t = at::Half; \
using weight_t = at::Half; \
__VA_ARGS__(); \
} else if (ITYPE == at::ScalarType::BFloat16) { \
using input_t = at::BFloat16; \
using weight_t = at::BFloat16; \
__VA_ARGS__(); \
} else if (ITYPE == at::ScalarType::Float) { \
using input_t = float; \
using weight_t = float; \
__VA_ARGS__(); \
} else { \
AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
}
template
<
typename
input_t
,
typename
weight_t
>
void
causal_conv1d_fwd_cuda
(
ConvParamsBase
&
params
,
cudaStream_t
stream
);
template
<
typename
input_t
,
typename
weight_t
>
void
causal_conv1d_update_cuda
(
ConvParamsBase
&
params
,
cudaStream_t
stream
);
void
set_conv_params_fwd
(
ConvParamsBase
&
params
,
// sizes
const
size_t
batch
,
const
size_t
dim
,
const
size_t
seqlen
,
const
size_t
width
,
// device pointers
const
at
::
Tensor
x
,
const
at
::
Tensor
weight
,
const
at
::
Tensor
out
,
const
c10
::
optional
<
at
::
Tensor
>&
bias
,
bool
silu_activation
,
int64_t
pad_slot_id
,
const
c10
::
optional
<
at
::
Tensor
>&
query_start_loc
=
std
::
nullopt
,
const
c10
::
optional
<
at
::
Tensor
>&
cache_indices
=
std
::
nullopt
,
const
c10
::
optional
<
at
::
Tensor
>&
has_initial_state
=
std
::
nullopt
)
{
// Reset the parameters
memset
(
&
params
,
0
,
sizeof
(
params
));
params
.
batch
=
batch
;
params
.
dim
=
dim
;
params
.
seqlen
=
seqlen
;
params
.
width
=
width
;
params
.
pad_slot_id
=
pad_slot_id
;
params
.
silu_activation
=
silu_activation
;
// Set the pointers and strides.
params
.
x_ptr
=
x
.
data_ptr
();
params
.
weight_ptr
=
weight
.
data_ptr
();
params
.
bias_ptr
=
bias
.
has_value
()
?
bias
.
value
().
data_ptr
()
:
nullptr
;
params
.
out_ptr
=
out
.
data_ptr
();
// All stride are in elements, not bytes.
params
.
query_start_loc_ptr
=
query_start_loc
.
has_value
()
?
query_start_loc
.
value
().
data_ptr
()
:
nullptr
;
params
.
cache_indices_ptr
=
cache_indices
.
has_value
()
?
cache_indices
.
value
().
data_ptr
()
:
nullptr
;
params
.
has_initial_state_ptr
=
has_initial_state
.
has_value
()
?
has_initial_state
.
value
().
data_ptr
()
:
nullptr
;
const
bool
varlen
=
params
.
query_start_loc_ptr
!=
nullptr
;
params
.
x_batch_stride
=
x
.
stride
(
varlen
?
1
:
0
);
params
.
x_c_stride
=
x
.
stride
(
varlen
?
0
:
1
);
params
.
x_l_stride
=
x
.
stride
(
varlen
?
1
:
-
1
);
params
.
weight_c_stride
=
weight
.
stride
(
0
);
params
.
weight_width_stride
=
weight
.
stride
(
1
);
params
.
out_batch_stride
=
out
.
stride
(
varlen
?
1
:
0
);
params
.
out_c_stride
=
out
.
stride
(
varlen
?
0
:
1
);
params
.
out_l_stride
=
out
.
stride
(
varlen
?
1
:
-
1
);
}
void
causal_conv1d_fwd
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
weight
,
const
c10
::
optional
<
at
::
Tensor
>
&
bias_
,
const
c10
::
optional
<
at
::
Tensor
>
&
conv_states
,
const
c10
::
optional
<
at
::
Tensor
>
&
query_start_loc
,
const
c10
::
optional
<
at
::
Tensor
>
&
cache_indices
,
const
c10
::
optional
<
at
::
Tensor
>
&
has_initial_state
,
bool
silu_activation
,
// used to identify padding entries if cache_indices provided
// in case of padding, the kernel will return early
int64_t
pad_slot_id
)
{
auto
input_type
=
x
.
scalar_type
();
auto
weight_type
=
weight
.
scalar_type
();
TORCH_CHECK
(
input_type
==
at
::
ScalarType
::
Float
||
input_type
==
at
::
ScalarType
::
Half
||
input_type
==
at
::
ScalarType
::
BFloat16
);
TORCH_CHECK
(
weight_type
==
at
::
ScalarType
::
Float
||
weight_type
==
at
::
ScalarType
::
Half
||
weight_type
==
at
::
ScalarType
::
BFloat16
);
TORCH_CHECK
(
x
.
is_cuda
());
TORCH_CHECK
(
weight
.
is_cuda
());
const
bool
varlen
=
query_start_loc
.
has_value
()
?
true
:
false
;
const
auto
sizes
=
x
.
sizes
();
const
int
batch_size
=
varlen
?
query_start_loc
.
value
().
sizes
()[
0
]
-
1
:
sizes
[
0
];
const
int
dim
=
varlen
?
sizes
[
0
]
:
sizes
[
1
];
const
int
seqlen
=
varlen
?
sizes
[
1
]
:
sizes
[
2
];
const
int
width
=
weight
.
size
(
-
1
);
if
(
varlen
){
CHECK_SHAPE
(
x
,
dim
,
seqlen
);
}
else
{
CHECK_SHAPE
(
x
,
batch_size
,
dim
,
seqlen
);
}
CHECK_SHAPE
(
weight
,
dim
,
width
);
if
(
bias_
.
has_value
())
{
auto
bias
=
bias_
.
value
();
TORCH_CHECK
(
bias
.
scalar_type
()
==
weight_type
);
TORCH_CHECK
(
bias
.
is_cuda
());
TORCH_CHECK
(
bias
.
stride
(
-
1
)
==
1
);
CHECK_SHAPE
(
bias
,
dim
);
}
if
(
has_initial_state
.
has_value
())
{
auto
has_initial_state_
=
has_initial_state
.
value
();
TORCH_CHECK
(
has_initial_state_
.
scalar_type
()
==
at
::
ScalarType
::
Bool
);
TORCH_CHECK
(
has_initial_state_
.
is_cuda
());
CHECK_SHAPE
(
has_initial_state_
,
batch_size
);
}
if
(
query_start_loc
.
has_value
())
{
auto
query_start_loc_
=
query_start_loc
.
value
();
TORCH_CHECK
(
query_start_loc_
.
scalar_type
()
==
at
::
ScalarType
::
Int
);
TORCH_CHECK
(
query_start_loc_
.
is_cuda
());
}
if
(
cache_indices
.
has_value
())
{
auto
cache_indices_
=
cache_indices
.
value
();
TORCH_CHECK
(
cache_indices_
.
scalar_type
()
==
at
::
ScalarType
::
Int
);
TORCH_CHECK
(
cache_indices_
.
is_cuda
());
CHECK_SHAPE
(
cache_indices_
,
batch_size
);
}
at
::
Tensor
out
=
x
;
ConvParamsBase
params
;
set_conv_params_fwd
(
params
,
batch_size
,
dim
,
seqlen
,
width
,
x
,
weight
,
out
,
bias_
,
silu_activation
,
pad_slot_id
,
query_start_loc
,
cache_indices
,
has_initial_state
);
if
(
conv_states
.
has_value
())
{
auto
conv_states_
=
conv_states
.
value
();
TORCH_CHECK
(
conv_states_
.
scalar_type
()
==
input_type
);
TORCH_CHECK
(
conv_states_
.
is_cuda
());
params
.
conv_states_ptr
=
conv_states_
.
data_ptr
();
params
.
conv_states_batch_stride
=
conv_states_
.
stride
(
0
);
params
.
conv_states_c_stride
=
conv_states_
.
stride
(
1
);
params
.
conv_states_l_stride
=
conv_states_
.
stride
(
2
);
}
else
{
params
.
conv_states_ptr
=
nullptr
;
}
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at
::
cuda
::
CUDAGuard
device_guard
{(
char
)
x
.
get_device
()};
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16
(
x
.
scalar_type
(),
"causal_conv1d_fwd"
,
[
&
]
{
causal_conv1d_fwd_cuda
<
input_t
,
weight_t
>
(
params
,
stream
);
});
}
void
causal_conv1d_update
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
conv_state
,
const
at
::
Tensor
&
weight
,
const
c10
::
optional
<
at
::
Tensor
>
&
bias_
,
bool
silu_activation
,
const
c10
::
optional
<
at
::
Tensor
>
&
cache_seqlens_
,
const
c10
::
optional
<
at
::
Tensor
>
&
conv_state_indices_
,
// used to identify padding entries if cache_indices provided
// in case of padding, the kernel will return early
int64_t
pad_slot_id
)
{
auto
input_type
=
x
.
scalar_type
();
auto
weight_type
=
weight
.
scalar_type
();
TORCH_CHECK
(
input_type
==
at
::
ScalarType
::
Float
||
input_type
==
at
::
ScalarType
::
Half
||
input_type
==
at
::
ScalarType
::
BFloat16
);
TORCH_CHECK
(
weight_type
==
at
::
ScalarType
::
Float
||
weight_type
==
at
::
ScalarType
::
Half
||
weight_type
==
at
::
ScalarType
::
BFloat16
);
TORCH_CHECK
(
weight_type
==
input_type
,
"weight type must equal to input type, other variations are disabled due to binary size limitations"
);
TORCH_CHECK
(
conv_state
.
scalar_type
()
==
input_type
);
TORCH_CHECK
(
x
.
is_cuda
());
TORCH_CHECK
(
conv_state
.
is_cuda
());
TORCH_CHECK
(
weight
.
is_cuda
());
const
auto
sizes
=
x
.
sizes
();
const
int
batch_size
=
sizes
[
0
];
const
int
dim
=
sizes
[
1
];
const
int
seqlen
=
sizes
[
2
];
const
int
width
=
weight
.
size
(
-
1
);
const
int
conv_state_len
=
conv_state
.
size
(
2
);
TORCH_CHECK
(
conv_state_len
>=
width
-
1
);
CHECK_SHAPE
(
x
,
batch_size
,
dim
,
seqlen
);
CHECK_SHAPE
(
weight
,
dim
,
width
);
TORCH_CHECK
(
width
>=
2
&&
width
<=
4
,
"causal_conv1d only supports width between 2 and 4"
);
if
(
bias_
.
has_value
())
{
auto
bias
=
bias_
.
value
();
TORCH_CHECK
(
bias
.
scalar_type
()
==
weight_type
);
TORCH_CHECK
(
bias
.
is_cuda
());
TORCH_CHECK
(
bias
.
stride
(
-
1
)
==
1
);
CHECK_SHAPE
(
bias
,
dim
);
}
at
::
Tensor
out
=
x
;
ConvParamsBase
params
;
set_conv_params_fwd
(
params
,
batch_size
,
dim
,
seqlen
,
width
,
x
,
weight
,
out
,
bias_
,
silu_activation
,
pad_slot_id
);
params
.
conv_state_ptr
=
conv_state
.
data_ptr
();
params
.
conv_state_len
=
conv_state_len
;
// All stride are in elements, not bytes.
params
.
conv_state_batch_stride
=
conv_state
.
stride
(
0
);
params
.
conv_state_c_stride
=
conv_state
.
stride
(
1
);
params
.
conv_state_l_stride
=
conv_state
.
stride
(
2
);
if
(
cache_seqlens_
.
has_value
())
{
auto
cache_seqlens
=
cache_seqlens_
.
value
();
TORCH_CHECK
(
cache_seqlens
.
scalar_type
()
==
torch
::
kInt32
);
TORCH_CHECK
(
cache_seqlens
.
is_cuda
());
TORCH_CHECK
(
cache_seqlens
.
stride
(
-
1
)
==
1
);
CHECK_SHAPE
(
cache_seqlens
,
batch_size
);
params
.
cache_seqlens
=
cache_seqlens
.
data_ptr
<
int32_t
>
();
}
else
{
params
.
cache_seqlens
=
nullptr
;
}
if
(
conv_state_indices_
.
has_value
())
{
auto
conv_state_indices
=
conv_state_indices_
.
value
();
TORCH_CHECK
(
conv_state_indices
.
scalar_type
()
==
torch
::
kInt32
)
TORCH_CHECK
(
conv_state_indices
.
is_cuda
());
TORCH_CHECK
(
conv_state_indices
.
stride
(
0
)
==
1
)
CHECK_SHAPE
(
conv_state_indices
,
batch_size
);
int
conv_state_entries
=
conv_state
.
size
(
0
);
CHECK_SHAPE
(
conv_state
,
conv_state_entries
,
dim
,
conv_state_len
);
params
.
conv_state_indices_ptr
=
conv_state_indices
.
data_ptr
<
int32_t
>
();
}
else
{
CHECK_SHAPE
(
conv_state
,
batch_size
,
dim
,
conv_state_len
);
params
.
conv_state_indices_ptr
=
nullptr
;
}
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at
::
cuda
::
CUDAGuard
device_guard
{(
char
)
x
.
get_device
()};
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16
(
x
.
scalar_type
(),
"causal_conv1d_update"
,
[
&
]
{
causal_conv1d_update_cuda
<
input_t
,
weight_t
>
(
params
,
stream
);
});
}
template
<
int
kNThreads_
,
int
kWidth_
,
bool
kIsVecLoad_
,
typename
input_t_
,
typename
weight_t_
>
struct
Causal_conv1d_fwd_kernel_traits
{
using
input_t
=
input_t_
;
using
weight_t
=
weight_t_
;
static
constexpr
int
kNThreads
=
kNThreads_
;
static
constexpr
int
kWidth
=
kWidth_
;
static
constexpr
int
kNBytes
=
sizeof
(
input_t
);
static_assert
(
kNBytes
==
2
||
kNBytes
==
4
);
static
constexpr
int
kNElts
=
kNBytes
==
4
?
4
:
8
;
static_assert
(
kWidth
<=
kNElts
);
static
constexpr
bool
kIsVecLoad
=
kIsVecLoad_
;
using
vec_t
=
typename
BytesToType
<
kNBytes
*
kNElts
>::
Type
;
using
BlockLoadT
=
cub
::
BlockLoad
<
input_t
,
kNThreads
,
kNElts
,
cub
::
BLOCK_LOAD_WARP_TRANSPOSE
>
;
using
BlockLoadVecT
=
cub
::
BlockLoad
<
vec_t
,
kNThreads
,
1
,
cub
::
BLOCK_LOAD_DIRECT
>
;
using
BlockStoreT
=
cub
::
BlockStore
<
input_t
,
kNThreads
,
kNElts
,
cub
::
BLOCK_STORE_WARP_TRANSPOSE
>
;
using
BlockStoreVecT
=
cub
::
BlockStore
<
vec_t
,
kNThreads
,
1
,
cub
::
BLOCK_STORE_DIRECT
>
;
static
constexpr
int
kSmemIOSize
=
kIsVecLoad
?
0
:
custom_max
({
sizeof
(
typename
BlockLoadT
::
TempStorage
),
sizeof
(
typename
BlockStoreT
::
TempStorage
)});
static
constexpr
int
kSmemExchangeSize
=
kNThreads
*
kNBytes
*
kNElts
;
static
constexpr
int
kSmemSize
=
kSmemIOSize
+
kSmemExchangeSize
;
};
template
<
typename
Ktraits
>
__global__
__launch_bounds__
(
Ktraits
::
kNThreads
)
void
causal_conv1d_fwd_kernel
(
ConvParamsBase
params
)
{
constexpr
int
kWidth
=
Ktraits
::
kWidth
;
constexpr
int
kNThreads
=
Ktraits
::
kNThreads
;
constexpr
int
kNElts
=
Ktraits
::
kNElts
;
constexpr
bool
kIsVecLoad
=
Ktraits
::
kIsVecLoad
;
using
input_t
=
typename
Ktraits
::
input_t
;
using
vec_t
=
typename
Ktraits
::
vec_t
;
using
weight_t
=
typename
Ktraits
::
weight_t
;
// Shared memory.
extern
__shared__
char
smem_
[];
auto
&
smem_load
=
reinterpret_cast
<
typename
Ktraits
::
BlockLoadT
::
TempStorage
&>
(
smem_
);
auto
&
smem_load_vec
=
reinterpret_cast
<
typename
Ktraits
::
BlockLoadVecT
::
TempStorage
&>
(
smem_
);
auto
&
smem_store
=
reinterpret_cast
<
typename
Ktraits
::
BlockStoreT
::
TempStorage
&>
(
smem_
);
auto
&
smem_store_vec
=
reinterpret_cast
<
typename
Ktraits
::
BlockStoreVecT
::
TempStorage
&>
(
smem_
);
vec_t
*
smem_exchange
=
reinterpret_cast
<
vec_t
*>
(
smem_
+
Ktraits
::
kSmemIOSize
);
const
bool
kVarlen
=
params
.
query_start_loc_ptr
!=
nullptr
;
const
int
tidx
=
threadIdx
.
x
;
const
int
batch_id
=
blockIdx
.
x
;
const
int
channel_id
=
blockIdx
.
y
;
const
int
*
query_start_loc
=
kVarlen
?
reinterpret_cast
<
int
*>
(
params
.
query_start_loc_ptr
)
:
nullptr
;
const
int
sequence_start_index
=
kVarlen
?
query_start_loc
[
batch_id
]
:
batch_id
;
const
int
seqlen
=
kVarlen
?
query_start_loc
[
batch_id
+
1
]
-
sequence_start_index
:
params
.
seqlen
;
input_t
*
x
=
reinterpret_cast
<
input_t
*>
(
params
.
x_ptr
)
+
sequence_start_index
*
params
.
x_batch_stride
+
channel_id
*
params
.
x_c_stride
;
weight_t
*
weight
=
reinterpret_cast
<
weight_t
*>
(
params
.
weight_ptr
)
+
channel_id
*
params
.
weight_c_stride
;
input_t
*
out
=
reinterpret_cast
<
input_t
*>
(
params
.
out_ptr
)
+
sequence_start_index
*
params
.
out_batch_stride
+
channel_id
*
params
.
out_c_stride
;
float
bias_val
=
params
.
bias_ptr
==
nullptr
?
0.
f
:
float
(
reinterpret_cast
<
weight_t
*>
(
params
.
bias_ptr
)[
channel_id
]);
bool
has_initial_state
=
params
.
has_initial_state_ptr
==
nullptr
?
false
:
reinterpret_cast
<
bool
*>
(
params
.
has_initial_state_ptr
)[
batch_id
];
int
*
cache_indices
=
params
.
cache_indices_ptr
==
nullptr
?
nullptr
:
reinterpret_cast
<
int
*>
(
params
.
cache_indices_ptr
);
int
cache_index
=
cache_indices
==
nullptr
?
batch_id
:
cache_indices
[
batch_id
];
// cache_index == params.pad_slot_id is defined as padding, so we exit early
if
(
cache_index
==
params
.
pad_slot_id
){
return
;
}
input_t
*
conv_states
=
params
.
conv_states_ptr
==
nullptr
?
nullptr
:
reinterpret_cast
<
input_t
*>
(
params
.
conv_states_ptr
)
+
cache_index
*
params
.
conv_states_batch_stride
+
channel_id
*
params
.
conv_states_c_stride
;
// Thread 0 will load the last elements of the previous chunk, so we initialize those to 0.
if
(
tidx
==
0
)
{
input_t
initial_state
[
kNElts
]
=
{
0
};
if
(
has_initial_state
)
{
#pragma unroll
for
(
int
w
=
0
;
w
<
kWidth
-
1
;
++
w
){
initial_state
[
kNElts
-
1
-
(
kWidth
-
2
)
+
w
]
=
conv_states
[
w
];
}
}
smem_exchange
[
kNThreads
-
1
]
=
reinterpret_cast
<
vec_t
*>
(
initial_state
)[
0
];
}
float
weight_vals
[
kWidth
];
#pragma unroll
for
(
int
i
=
0
;
i
<
kWidth
;
++
i
)
{
weight_vals
[
i
]
=
float
(
weight
[
i
*
params
.
weight_width_stride
]);
}
constexpr
int
kChunkSize
=
kNThreads
*
kNElts
;
const
int
n_chunks
=
(
seqlen
+
kChunkSize
-
1
)
/
kChunkSize
;
for
(
int
chunk
=
0
;
chunk
<
n_chunks
;
++
chunk
)
{
input_t
x_vals_load
[
2
*
kNElts
]
=
{
0
};
if
constexpr
(
kIsVecLoad
)
{
typename
Ktraits
::
BlockLoadVecT
(
smem_load_vec
).
Load
(
reinterpret_cast
<
vec_t
*>
(
x
),
*
reinterpret_cast
<
vec_t
(
*
)[
1
]
>
(
&
x_vals_load
[
kNElts
]),
(
seqlen
-
chunk
*
kChunkSize
)
/
kNElts
);
}
else
{
__syncthreads
();
typename
Ktraits
::
BlockLoadT
(
smem_load
).
Load
(
x
,
*
reinterpret_cast
<
input_t
(
*
)[
kNElts
]
>
(
&
x_vals_load
[
kNElts
]),
seqlen
-
chunk
*
kChunkSize
);
}
x
+=
kChunkSize
;
__syncthreads
();
// Thread kNThreads - 1 don't write yet, so that thread 0 can read
// the last elements of the previous chunk.
if
(
tidx
<
kNThreads
-
1
)
{
smem_exchange
[
tidx
]
=
reinterpret_cast
<
vec_t
*>
(
x_vals_load
)[
1
];
}
__syncthreads
();
reinterpret_cast
<
vec_t
*>
(
x_vals_load
)[
0
]
=
smem_exchange
[
tidx
>
0
?
tidx
-
1
:
kNThreads
-
1
];
__syncthreads
();
// Now thread kNThreads - 1 can write the last elements of the current chunk.
if
(
tidx
==
kNThreads
-
1
)
{
smem_exchange
[
tidx
]
=
reinterpret_cast
<
vec_t
*>
(
x_vals_load
)[
1
];
}
float
x_vals
[
2
*
kNElts
];
#pragma unroll
for
(
int
i
=
0
;
i
<
2
*
kNElts
;
++
i
)
{
x_vals
[
i
]
=
float
(
x_vals_load
[
i
]);
}
float
out_vals
[
kNElts
];
#pragma unroll
for
(
int
i
=
0
;
i
<
kNElts
;
++
i
)
{
out_vals
[
i
]
=
bias_val
;
#pragma unroll
for
(
int
w
=
0
;
w
<
kWidth
;
++
w
)
{
out_vals
[
i
]
+=
weight_vals
[
w
]
*
x_vals
[
kNElts
+
i
-
(
kWidth
-
w
-
1
)];
}
}
if
(
params
.
silu_activation
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
kNElts
;
++
i
)
{
out_vals
[
i
]
=
out_vals
[
i
]
/
(
1
+
expf
(
-
out_vals
[
i
]));
}
}
input_t
out_vals_store
[
kNElts
];
#pragma unroll
for
(
int
i
=
0
;
i
<
kNElts
;
++
i
)
{
out_vals_store
[
i
]
=
out_vals
[
i
];
}
if
constexpr
(
kIsVecLoad
)
{
typename
Ktraits
::
BlockStoreVecT
(
smem_store_vec
).
Store
(
reinterpret_cast
<
vec_t
*>
(
out
),
reinterpret_cast
<
vec_t
(
&
)[
1
]
>
(
out_vals_store
),
(
seqlen
-
chunk
*
kChunkSize
)
/
kNElts
);
}
else
{
typename
Ktraits
::
BlockStoreT
(
smem_store
).
Store
(
out
,
out_vals_store
,
seqlen
-
chunk
*
kChunkSize
);
}
out
+=
kChunkSize
;
}
// Final state is stored in the smem_exchange last token slot,
// in case seqlen < kWidth, we would need to take the final state from the
// initial state which is stored in conv_states
// in case seqlen > kWidth, we would need to load the last kWidth - 1 data
// and load it into conv_state accordingly
int
last_thread
=
((
seqlen
-
(
kWidth
-
1
))
-
(
n_chunks
-
1
)
*
kChunkSize
)
/
kNElts
;
if
(
conv_states
!=
nullptr
&&
tidx
==
last_thread
)
{
input_t
x_vals_load
[
kNElts
*
2
]
=
{
0
};
// in case we are on the first kWidth tokens
if
(
last_thread
==
0
&&
seqlen
<
kWidth
){
// Need to take the initial state
reinterpret_cast
<
vec_t
*>
(
x_vals_load
)[
0
]
=
smem_exchange
[
0
];
const
int
offset
=
seqlen
-
(
kWidth
-
1
);
#pragma unroll
for
(
int
w
=
0
;
w
<
kWidth
-
1
;
++
w
){
// pad the existing state
if
((
w
-
seqlen
)
>=
0
&&
has_initial_state
)
{
conv_states
[
w
-
seqlen
]
=
conv_states
[
w
];
}
else
if
((
w
-
seqlen
)
>=
0
&&
!
has_initial_state
)
{
conv_states
[
w
-
seqlen
]
=
input_t
(
0.0
f
);
}
}
#pragma unroll
for
(
int
w
=
0
;
w
<
kWidth
-
1
;
++
w
){
if
(
offset
+
w
>=
0
)
conv_states
[
w
]
=
x_vals_load
[
offset
+
w
];
}
}
else
{
// in case the final state is in between the threads data
reinterpret_cast
<
vec_t
*>
(
x_vals_load
)[
1
]
=
smem_exchange
[
last_thread
+
1
];
reinterpret_cast
<
vec_t
*>
(
x_vals_load
)[
0
]
=
smem_exchange
[
last_thread
];
const
int
offset
=
((
seqlen
-
(
kWidth
-
1
))
%
(
kNElts
));
#pragma unroll
for
(
int
w
=
0
;
w
<
kWidth
-
1
;
++
w
){
conv_states
[
w
]
=
x_vals_load
[
offset
+
w
];
}
}
}
}
template
<
int
kNThreads
,
int
kWidth
,
typename
input_t
,
typename
weight_t
>
void
causal_conv1d_fwd_launch
(
ConvParamsBase
&
params
,
cudaStream_t
stream
)
{
static
constexpr
int
kNElts
=
sizeof
(
input_t
)
==
4
?
4
:
8
;
const
bool
kVarlen
=
params
.
query_start_loc_ptr
!=
nullptr
;
BOOL_SWITCH
(
params
.
seqlen
%
kNElts
==
0
&&
!
kVarlen
,
kIsVecLoad
,
[
&
]
{
using
Ktraits
=
Causal_conv1d_fwd_kernel_traits
<
kNThreads
,
kWidth
,
kIsVecLoad
,
input_t
,
weight_t
>
;
constexpr
int
kSmemSize
=
Ktraits
::
kSmemSize
;
dim3
grid
(
params
.
batch
,
params
.
dim
);
auto
kernel
=
&
causal_conv1d_fwd_kernel
<
Ktraits
>
;
if
(
kSmemSize
>=
48
*
1024
)
{
#ifndef USE_ROCM
C10_CUDA_CHECK
(
cudaFuncSetAttribute
(
kernel
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
kSmemSize
));
#else
// There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function.
C10_CUDA_CHECK
(
cudaFuncSetAttribute
(
(
void
*
)
kernel
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
kSmemSize
));
std
::
cerr
<<
"Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior.
\n
"
<<
std
::
endl
;
#endif
}
kernel
<<<
grid
,
Ktraits
::
kNThreads
,
kSmemSize
,
stream
>>>
(
params
);
C10_CUDA_KERNEL_LAUNCH_CHECK
();
});
}
template
<
typename
input_t
,
typename
weight_t
>
void
causal_conv1d_fwd_cuda
(
ConvParamsBase
&
params
,
cudaStream_t
stream
)
{
if
(
params
.
width
==
2
)
{
causal_conv1d_fwd_launch
<
128
,
2
,
input_t
,
weight_t
>
(
params
,
stream
);
}
else
if
(
params
.
width
==
3
)
{
causal_conv1d_fwd_launch
<
128
,
3
,
input_t
,
weight_t
>
(
params
,
stream
);
}
else
if
(
params
.
width
==
4
)
{
causal_conv1d_fwd_launch
<
128
,
4
,
input_t
,
weight_t
>
(
params
,
stream
);
}
}
template
void
causal_conv1d_fwd_cuda
<
float
,
float
>(
ConvParamsBase
&
params
,
cudaStream_t
stream
);
template
void
causal_conv1d_fwd_cuda
<
at
::
Half
,
at
::
Half
>(
ConvParamsBase
&
params
,
cudaStream_t
stream
);
template
void
causal_conv1d_fwd_cuda
<
at
::
BFloat16
,
at
::
BFloat16
>(
ConvParamsBase
&
params
,
cudaStream_t
stream
);
template
<
int
kNThreads_
,
int
kWidth_
,
typename
input_t_
,
typename
weight_t_
>
struct
Causal_conv1d_update_kernel_traits
{
using
input_t
=
input_t_
;
using
weight_t
=
weight_t_
;
static
constexpr
int
kNThreads
=
kNThreads_
;
static
constexpr
int
kWidth
=
kWidth_
;
static
constexpr
int
kNBytes
=
sizeof
(
input_t
);
static_assert
(
kNBytes
==
2
||
kNBytes
==
4
);
};
template
<
typename
Ktraits
,
bool
kIsCircularBuffer
>
__global__
__launch_bounds__
(
Ktraits
::
kNThreads
)
void
causal_conv1d_update_kernel
(
ConvParamsBase
params
)
{
constexpr
int
kWidth
=
Ktraits
::
kWidth
;
constexpr
int
kNThreads
=
Ktraits
::
kNThreads
;
using
input_t
=
typename
Ktraits
::
input_t
;
using
weight_t
=
typename
Ktraits
::
weight_t
;
const
int
tidx
=
threadIdx
.
x
;
const
int
batch_id
=
blockIdx
.
x
;
const
int
channel_id
=
blockIdx
.
y
*
kNThreads
+
tidx
;
if
(
channel_id
>=
params
.
dim
)
return
;
input_t
*
x
=
reinterpret_cast
<
input_t
*>
(
params
.
x_ptr
)
+
batch_id
*
params
.
x_batch_stride
+
channel_id
*
params
.
x_c_stride
;
// If params.conv_state_batch_indices is set, then the conv state is gathered from the conv state tensor
// along the batch axis. Otherwise, the conv state coordinate is the same as the batch id.
const
int
conv_state_batch_coord
=
params
.
conv_state_indices_ptr
==
nullptr
?
batch_id
:
params
.
conv_state_indices_ptr
[
batch_id
];
// conv_state_batch_coord == params.pad_slot_id is defined as padding so we exit early
if
(
conv_state_batch_coord
==
params
.
pad_slot_id
){
return
;
}
input_t
*
conv_state
=
reinterpret_cast
<
input_t
*>
(
params
.
conv_state_ptr
)
+
conv_state_batch_coord
*
params
.
conv_state_batch_stride
+
channel_id
*
params
.
conv_state_c_stride
;
weight_t
*
weight
=
reinterpret_cast
<
weight_t
*>
(
params
.
weight_ptr
)
+
channel_id
*
params
.
weight_c_stride
;
input_t
*
out
=
reinterpret_cast
<
input_t
*>
(
params
.
out_ptr
)
+
batch_id
*
params
.
out_batch_stride
+
channel_id
*
params
.
out_c_stride
;
float
bias_val
=
params
.
bias_ptr
==
nullptr
?
0.
f
:
float
(
reinterpret_cast
<
weight_t
*>
(
params
.
bias_ptr
)[
channel_id
]);
int
state_len
=
params
.
conv_state_len
;
int
advance_len
=
params
.
seqlen
;
int
cache_seqlen
=
kIsCircularBuffer
?
params
.
cache_seqlens
[
batch_id
]
%
state_len
:
0
;
int
update_idx
=
cache_seqlen
-
(
kWidth
-
1
);
update_idx
=
update_idx
<
0
?
update_idx
+
state_len
:
update_idx
;
float
weight_vals
[
kWidth
]
=
{
0
};
#pragma unroll
for
(
int
i
=
0
;
i
<
kWidth
;
++
i
)
{
weight_vals
[
i
]
=
float
(
weight
[
i
*
params
.
weight_width_stride
]);
}
float
x_vals
[
kWidth
]
=
{
0
};
if
constexpr
(
!
kIsCircularBuffer
)
{
#pragma unroll 2
for
(
int
i
=
0
;
i
<
state_len
-
advance_len
-
(
kWidth
-
1
);
++
i
)
{
conv_state
[
i
*
params
.
conv_state_l_stride
]
=
conv_state
[(
i
+
advance_len
)
*
params
.
conv_state_l_stride
];
}
#pragma unroll
for
(
int
i
=
0
;
i
<
kWidth
-
1
;
++
i
)
{
input_t
state_val
=
conv_state
[(
state_len
-
(
kWidth
-
1
)
+
i
)
*
params
.
conv_state_l_stride
];
if
(
i
<
advance_len
+
(
kWidth
-
1
)
&&
state_len
-
advance_len
-
(
kWidth
-
1
)
+
i
>=
0
)
{
conv_state
[(
state_len
-
advance_len
-
(
kWidth
-
1
)
+
i
)
*
params
.
conv_state_l_stride
]
=
state_val
;
}
x_vals
[
i
]
=
float
(
state_val
);
}
}
else
{
#pragma unroll
for
(
int
i
=
0
;
i
<
kWidth
-
1
;
++
i
,
update_idx
=
update_idx
+
1
>=
state_len
?
update_idx
+
1
-
state_len
:
update_idx
+
1
)
{
input_t
state_val
=
conv_state
[
update_idx
*
params
.
conv_state_l_stride
];
x_vals
[
i
]
=
float
(
state_val
);
}
}
#pragma unroll 2
for
(
int
i
=
0
;
i
<
params
.
seqlen
;
++
i
)
{
input_t
x_val
=
x
[
i
*
params
.
x_l_stride
];
if
constexpr
(
!
kIsCircularBuffer
)
{
if
(
i
<
advance_len
&&
state_len
-
advance_len
+
i
>=
0
)
{
conv_state
[(
state_len
-
advance_len
+
i
)
*
params
.
conv_state_l_stride
]
=
x_val
;
}
}
else
{
conv_state
[
update_idx
*
params
.
conv_state_l_stride
]
=
x_val
;
++
update_idx
;
update_idx
=
update_idx
>=
state_len
?
update_idx
-
state_len
:
update_idx
;
}
x_vals
[
kWidth
-
1
]
=
float
(
x_val
);
float
out_val
=
bias_val
;
#pragma unroll
for
(
int
j
=
0
;
j
<
kWidth
;
++
j
)
{
out_val
+=
weight_vals
[
j
]
*
x_vals
[
j
];
}
if
(
params
.
silu_activation
)
{
out_val
=
out_val
/
(
1
+
expf
(
-
out_val
));
}
out
[
i
*
params
.
out_l_stride
]
=
input_t
(
out_val
);
// Shift the input buffer by 1
#pragma unroll
for
(
int
i
=
0
;
i
<
kWidth
-
1
;
++
i
)
{
x_vals
[
i
]
=
x_vals
[
i
+
1
];
}
}
}
template
<
int
kNThreads
,
int
kWidth
,
typename
input_t
,
typename
weight_t
>
void
causal_conv1d_update_launch
(
ConvParamsBase
&
params
,
cudaStream_t
stream
)
{
using
Ktraits
=
Causal_conv1d_update_kernel_traits
<
kNThreads
,
kWidth
,
input_t
,
weight_t
>
;
dim3
grid
(
params
.
batch
,
(
params
.
dim
+
kNThreads
-
1
)
/
kNThreads
);
auto
kernel
=
params
.
cache_seqlens
==
nullptr
?
&
causal_conv1d_update_kernel
<
Ktraits
,
false
>
:
&
causal_conv1d_update_kernel
<
Ktraits
,
true
>
;
kernel
<<<
grid
,
Ktraits
::
kNThreads
,
0
,
stream
>>>
(
params
);
C10_CUDA_KERNEL_LAUNCH_CHECK
();
}
template
<
typename
input_t
,
typename
weight_t
>
void
causal_conv1d_update_cuda
(
ConvParamsBase
&
params
,
cudaStream_t
stream
)
{
if
(
params
.
width
==
2
)
{
causal_conv1d_update_launch
<
64
,
2
,
input_t
,
weight_t
>
(
params
,
stream
);
}
else
if
(
params
.
width
==
3
)
{
causal_conv1d_update_launch
<
64
,
3
,
input_t
,
weight_t
>
(
params
,
stream
);
}
else
if
(
params
.
width
==
4
)
{
causal_conv1d_update_launch
<
64
,
4
,
input_t
,
weight_t
>
(
params
,
stream
);
}
}
template
void
causal_conv1d_update_cuda
<
float
,
float
>(
ConvParamsBase
&
params
,
cudaStream_t
stream
);
template
void
causal_conv1d_update_cuda
<
at
::
Half
,
at
::
Half
>(
ConvParamsBase
&
params
,
cudaStream_t
stream
);
template
void
causal_conv1d_update_cuda
<
at
::
BFloat16
,
at
::
BFloat16
>(
ConvParamsBase
&
params
,
cudaStream_t
stream
);
csrc/mamba/causal_conv1d/causal_conv1d.h
0 → 100644
View file @
ad385667
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
// clang-format off
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d.h
#pragma once
#include <cuda_bf16.h>
#include <cuda_fp16.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
struct
ConvParamsBase
{
using
index_t
=
uint32_t
;
int
batch
,
dim
,
seqlen
,
width
;
int64_t
pad_slot_id
;
bool
silu_activation
;
index_t
x_batch_stride
;
index_t
x_c_stride
;
index_t
x_l_stride
;
index_t
weight_c_stride
;
index_t
weight_width_stride
;
index_t
out_batch_stride
;
index_t
out_c_stride
;
index_t
out_l_stride
;
int
conv_state_len
;
index_t
conv_state_batch_stride
;
index_t
conv_state_c_stride
;
index_t
conv_state_l_stride
;
// Common data pointers.
void
*
__restrict__
x_ptr
;
void
*
__restrict__
weight_ptr
;
void
*
__restrict__
bias_ptr
;
void
*
__restrict__
out_ptr
;
void
*
__restrict__
conv_state_ptr
;
void
*
__restrict__
query_start_loc_ptr
;
void
*
__restrict__
has_initial_state_ptr
;
void
*
__restrict__
cache_indices_ptr
;
int32_t
*
__restrict__
cache_seqlens
;
// For the continuous batching case. Makes it so that the mamba state for
// the current batch doesn't need to be a contiguous tensor.
int32_t
*
__restrict__
conv_state_indices_ptr
;
void
*
__restrict__
seq_idx_ptr
;
// No __restrict__ since initial_states could be the same as final_states.
void
*
initial_states_ptr
;
index_t
initial_states_batch_stride
;
index_t
initial_states_l_stride
;
index_t
initial_states_c_stride
;
void
*
final_states_ptr
;
index_t
final_states_batch_stride
;
index_t
final_states_l_stride
;
index_t
final_states_c_stride
;
void
*
conv_states_ptr
;
index_t
conv_states_batch_stride
;
index_t
conv_states_l_stride
;
index_t
conv_states_c_stride
;
};
#ifndef USE_ROCM
#include <cuda_bf16.h>
template
<
typename
T
>
__device__
inline
T
shuffle_xor
(
T
val
,
int
offset
)
{
return
__shfl_xor_sync
(
uint32_t
(
-
1
),
val
,
offset
);
}
constexpr
size_t
custom_max
(
std
::
initializer_list
<
size_t
>
ilist
)
{
return
std
::
max
(
ilist
);
}
template
<
typename
T
>
constexpr
T
constexpr_min
(
T
a
,
T
b
)
{
return
std
::
min
(
a
,
b
);
}
#else
#include <hip/hip_bf16.h>
template
<
typename
T
>
__device__
inline
T
shuffle_xor
(
T
val
,
int
offset
)
{
return
__shfl_xor
(
val
,
offset
);
}
constexpr
size_t
custom_max
(
std
::
initializer_list
<
size_t
>
ilist
)
{
return
*
std
::
max_element
(
ilist
.
begin
(),
ilist
.
end
());
}
template
<
typename
T
>
constexpr
T
constexpr_min
(
T
a
,
T
b
)
{
return
a
<
b
?
a
:
b
;
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
BYTES
>
struct
BytesToType
{};
template
<
>
struct
BytesToType
<
16
>
{
using
Type
=
uint4
;
static_assert
(
sizeof
(
Type
)
==
16
);
};
template
<
>
struct
BytesToType
<
8
>
{
using
Type
=
uint64_t
;
static_assert
(
sizeof
(
Type
)
==
8
);
};
template
<
>
struct
BytesToType
<
4
>
{
using
Type
=
uint32_t
;
static_assert
(
sizeof
(
Type
)
==
4
);
};
template
<
>
struct
BytesToType
<
2
>
{
using
Type
=
uint16_t
;
static_assert
(
sizeof
(
Type
)
==
2
);
};
template
<
>
struct
BytesToType
<
1
>
{
using
Type
=
uint8_t
;
static_assert
(
sizeof
(
Type
)
==
1
);
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
T
>
struct
SumOp
{
__device__
inline
T
operator
()(
T
const
&
x
,
T
const
&
y
)
{
return
x
+
y
;
}
};
template
<
int
THREADS
>
struct
Allreduce
{
static_assert
(
THREADS
==
32
||
THREADS
==
16
||
THREADS
==
8
||
THREADS
==
4
);
template
<
typename
T
,
typename
Operator
>
static
__device__
inline
T
run
(
T
x
,
Operator
&
op
)
{
constexpr
int
OFFSET
=
THREADS
/
2
;
x
=
op
(
x
,
__shfl_xor_sync
(
uint32_t
(
-
1
),
x
,
OFFSET
));
return
Allreduce
<
OFFSET
>::
run
(
x
,
op
);
}
};
template
<
>
struct
Allreduce
<
2
>
{
template
<
typename
T
,
typename
Operator
>
static
__device__
inline
T
run
(
T
x
,
Operator
&
op
)
{
x
=
op
(
x
,
__shfl_xor_sync
(
uint32_t
(
-
1
),
x
,
1
));
return
x
;
}
};
csrc/mamba/causal_conv1d/static_switch.h
0 → 100644
View file @
ad385667
// Inspired by
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
// clang-format off
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/static_switch.h
#pragma once
/// @param COND - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ... - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
/// some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
static constexpr bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
static constexpr bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
csrc/mamba/mamba_ssm/selective_scan.h
0 → 100644
View file @
ad385667
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
// clang-format off
// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan.h
#pragma once
#ifndef USE_ROCM
#include <cuda_bf16.h>
#else
#include <hip/hip_bf16.h>
#endif
#include <cuda_fp16.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
struct
SSMParamsBase
{
using
index_t
=
uint32_t
;
int
batch
,
dim
,
seqlen
,
dstate
,
n_groups
,
n_chunks
;
int
dim_ngroups_ratio
;
bool
is_variable_B
;
bool
is_variable_C
;
int64_t
pad_slot_id
;
bool
delta_softplus
;
index_t
A_d_stride
;
index_t
A_dstate_stride
;
index_t
B_batch_stride
;
index_t
B_d_stride
;
index_t
B_dstate_stride
;
index_t
B_group_stride
;
index_t
C_batch_stride
;
index_t
C_d_stride
;
index_t
C_dstate_stride
;
index_t
C_group_stride
;
index_t
u_batch_stride
;
index_t
u_d_stride
;
index_t
delta_batch_stride
;
index_t
delta_d_stride
;
index_t
z_batch_stride
;
index_t
z_d_stride
;
index_t
out_batch_stride
;
index_t
out_d_stride
;
index_t
out_z_batch_stride
;
index_t
out_z_d_stride
;
// Common data pointers.
void
*
__restrict__
A_ptr
;
void
*
__restrict__
B_ptr
;
void
*
__restrict__
C_ptr
;
void
*
__restrict__
D_ptr
;
void
*
__restrict__
u_ptr
;
void
*
__restrict__
delta_ptr
;
void
*
__restrict__
delta_bias_ptr
;
void
*
__restrict__
out_ptr
;
void
*
__restrict__
ssm_states_ptr
;
void
*
__restrict__
z_ptr
;
void
*
__restrict__
out_z_ptr
;
void
*
__restrict__
query_start_loc_ptr
;
void
*
__restrict__
cache_indices_ptr
;
void
*
__restrict__
has_initial_state_ptr
;
};
#ifndef USE_ROCM
constexpr
size_t
custom_max
(
std
::
initializer_list
<
size_t
>
ilist
)
{
return
std
::
max
(
ilist
);
}
template
<
typename
T
>
constexpr
T
constexpr_min
(
T
a
,
T
b
)
{
return
std
::
min
(
a
,
b
);
}
#else
constexpr
size_t
custom_max
(
std
::
initializer_list
<
size_t
>
ilist
)
{
return
*
std
::
max_element
(
ilist
.
begin
(),
ilist
.
end
());
}
template
<
typename
T
>
constexpr
T
constexpr_min
(
T
a
,
T
b
)
{
return
a
<
b
?
a
:
b
;
}
#endif
#define MAX_DSTATE 256
inline
__device__
float2
operator
+
(
const
float2
&
a
,
const
float2
&
b
){
return
{
a
.
x
+
b
.
x
,
a
.
y
+
b
.
y
};
}
inline
__device__
float3
operator
+
(
const
float3
&
a
,
const
float3
&
b
)
{
return
{
a
.
x
+
b
.
x
,
a
.
y
+
b
.
y
,
a
.
z
+
b
.
z
};
}
inline
__device__
float4
operator
+
(
const
float4
&
a
,
const
float4
&
b
){
return
{
a
.
x
+
b
.
x
,
a
.
y
+
b
.
y
,
a
.
z
+
b
.
z
,
a
.
w
+
b
.
w
};
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
BYTES
>
struct
BytesToType
{};
template
<
>
struct
BytesToType
<
16
>
{
using
Type
=
uint4
;
static_assert
(
sizeof
(
Type
)
==
16
);
};
template
<
>
struct
BytesToType
<
8
>
{
using
Type
=
uint64_t
;
static_assert
(
sizeof
(
Type
)
==
8
);
};
template
<
>
struct
BytesToType
<
4
>
{
using
Type
=
uint32_t
;
static_assert
(
sizeof
(
Type
)
==
4
);
};
template
<
>
struct
BytesToType
<
2
>
{
using
Type
=
uint16_t
;
static_assert
(
sizeof
(
Type
)
==
2
);
};
template
<
>
struct
BytesToType
<
1
>
{
using
Type
=
uint8_t
;
static_assert
(
sizeof
(
Type
)
==
1
);
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
scalar_t
,
int
N
>
struct
Converter
{
static
inline
__device__
void
to_float
(
const
scalar_t
(
&
src
)[
N
],
float
(
&
dst
)[
N
])
{
#pragma unroll
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
dst
[
i
]
=
src
[
i
];
}
}
};
template
<
int
N
>
struct
Converter
<
at
::
Half
,
N
>
{
static
inline
__device__
void
to_float
(
const
at
::
Half
(
&
src
)[
N
],
float
(
&
dst
)[
N
])
{
static_assert
(
N
%
2
==
0
);
auto
&
src2
=
reinterpret_cast
<
const
half2
(
&
)[
N
/
2
]
>
(
src
);
auto
&
dst2
=
reinterpret_cast
<
float2
(
&
)[
N
/
2
]
>
(
dst
);
#pragma unroll
for
(
int
i
=
0
;
i
<
N
/
2
;
++
i
)
{
dst2
[
i
]
=
__half22float2
(
src2
[
i
]);
}
}
};
#if __CUDA_ARCH__ >= 800
template
<
int
N
>
struct
Converter
<
at
::
BFloat16
,
N
>
{
static
inline
__device__
void
to_float
(
const
at
::
BFloat16
(
&
src
)[
N
],
float
(
&
dst
)[
N
])
{
static_assert
(
N
%
2
==
0
);
auto
&
src2
=
reinterpret_cast
<
const
nv_bfloat162
(
&
)[
N
/
2
]
>
(
src
);
auto
&
dst2
=
reinterpret_cast
<
float2
(
&
)[
N
/
2
]
>
(
dst
);
#pragma unroll
for
(
int
i
=
0
;
i
<
N
/
2
;
++
i
)
{
dst2
[
i
]
=
__bfloat1622float2
(
src2
[
i
]);
}
}
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
scalar_t
>
struct
SSMScanOp
;
template
<
>
struct
SSMScanOp
<
float
>
{
__device__
__forceinline__
float2
operator
()(
const
float2
&
ab0
,
const
float2
&
ab1
)
const
{
return
make_float2
(
ab1
.
x
*
ab0
.
x
,
ab1
.
x
*
ab0
.
y
+
ab1
.
y
);
}
};
// A stateful callback functor that maintains a running prefix to be applied
// during consecutive scan operations.
template
<
typename
scalar_t
>
struct
SSMScanPrefixCallbackOp
{
using
scan_t
=
std
::
conditional_t
<
std
::
is_same_v
<
scalar_t
,
float
>
,
float2
,
float4
>
;
scan_t
running_prefix
;
// Constructor
__device__
SSMScanPrefixCallbackOp
(
scan_t
running_prefix_
)
:
running_prefix
(
running_prefix_
)
{}
// Callback operator to be entered by the first warp of threads in the block.
// Thread-0 is responsible for returning a value for seeding the block-wide scan.
__device__
scan_t
operator
()(
scan_t
block_aggregate
)
{
scan_t
old_prefix
=
running_prefix
;
running_prefix
=
SSMScanOp
<
scalar_t
>
()(
running_prefix
,
block_aggregate
);
return
old_prefix
;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
Ktraits
>
inline
__device__
void
load_input
(
typename
Ktraits
::
input_t
*
u
,
typename
Ktraits
::
input_t
(
&
u_vals
)[
Ktraits
::
kNItems
],
typename
Ktraits
::
BlockLoadT
::
TempStorage
&
smem_load
,
int
seqlen
)
{
if
constexpr
(
Ktraits
::
kIsEvenLen
&&
!
Ktraits
::
kVarlen
)
{
auto
&
smem_load_vec
=
reinterpret_cast
<
typename
Ktraits
::
BlockLoadVecT
::
TempStorage
&>
(
smem_load
);
using
vec_t
=
typename
Ktraits
::
vec_t
;
typename
Ktraits
::
BlockLoadVecT
(
smem_load_vec
).
Load
(
reinterpret_cast
<
vec_t
*>
(
u
),
reinterpret_cast
<
vec_t
(
&
)[
Ktraits
::
kNLoads
]
>
(
u_vals
)
#ifdef USE_ROCM
,
Ktraits
::
kNThreads
*
Ktraits
::
kNLoads
#endif
);
}
else
{
typename
Ktraits
::
BlockLoadT
(
smem_load
).
Load
(
u
,
u_vals
,
seqlen
,
0.
f
);
}
}
template
<
typename
Ktraits
>
inline
__device__
void
load_weight
(
typename
Ktraits
::
input_t
*
Bvar
,
typename
Ktraits
::
weight_t
(
&
B_vals
)[
Ktraits
::
kNItems
],
typename
Ktraits
::
BlockLoadWeightT
::
TempStorage
&
smem_load_weight
,
int
seqlen
)
{
constexpr
int
kNItems
=
Ktraits
::
kNItems
;
typename
Ktraits
::
input_t
B_vals_load
[
kNItems
];
if
constexpr
(
Ktraits
::
kIsEvenLen
&&
!
Ktraits
::
kVarlen
)
{
auto
&
smem_load_weight_vec
=
reinterpret_cast
<
typename
Ktraits
::
BlockLoadWeightVecT
::
TempStorage
&>
(
smem_load_weight
);
using
vec_t
=
typename
Ktraits
::
vec_t
;
typename
Ktraits
::
BlockLoadWeightVecT
(
smem_load_weight_vec
).
Load
(
reinterpret_cast
<
vec_t
*>
(
Bvar
),
reinterpret_cast
<
vec_t
(
&
)[
Ktraits
::
kNLoads
]
>
(
B_vals_load
)
);
}
else
{
typename
Ktraits
::
BlockLoadWeightT
(
smem_load_weight
).
Load
(
Bvar
,
B_vals_load
,
seqlen
,
0.
f
);
}
// #pragma unroll
// for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; }
Converter
<
typename
Ktraits
::
input_t
,
kNItems
>::
to_float
(
B_vals_load
,
B_vals
);
}
template
<
typename
Ktraits
>
inline
__device__
void
store_output
(
typename
Ktraits
::
input_t
*
out
,
const
float
(
&
out_vals
)[
Ktraits
::
kNItems
],
typename
Ktraits
::
BlockStoreT
::
TempStorage
&
smem_store
,
int
seqlen
)
{
typename
Ktraits
::
input_t
write_vals
[
Ktraits
::
kNItems
];
#pragma unroll
for
(
int
i
=
0
;
i
<
Ktraits
::
kNItems
;
++
i
)
{
write_vals
[
i
]
=
out_vals
[
i
];
}
if
constexpr
(
Ktraits
::
kIsEvenLen
&&
!
Ktraits
::
kVarlen
)
{
auto
&
smem_store_vec
=
reinterpret_cast
<
typename
Ktraits
::
BlockStoreVecT
::
TempStorage
&>
(
smem_store
);
using
vec_t
=
typename
Ktraits
::
vec_t
;
typename
Ktraits
::
BlockStoreVecT
(
smem_store_vec
).
Store
(
reinterpret_cast
<
vec_t
*>
(
out
),
reinterpret_cast
<
vec_t
(
&
)[
Ktraits
::
kNLoads
]
>
(
write_vals
)
);
}
else
{
typename
Ktraits
::
BlockStoreT
(
smem_store
).
Store
(
out
,
write_vals
,
seqlen
);
}
}
csrc/mamba/mamba_ssm/selective_scan_fwd.cu
0 → 100644
View file @
ad385667
// clang-format off
// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan_fwd_kernel.cuh
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "selective_scan.h"
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
#ifndef USE_ROCM
#include <cub/block/block_load.cuh>
#include <cub/block/block_store.cuh>
#include <cub/block/block_scan.cuh>
#else
#include <hipcub/hipcub.hpp>
namespace
cub
=
hipcub
;
#endif
#include "selective_scan.h"
#include "static_switch.h"
template
<
int
kNThreads_
,
int
kNItems_
,
int
kNRows_
,
bool
kIsEvenLen_
,
bool
kIsVariableB_
,
bool
kIsVariableC_
,
bool
kHasZ_
,
bool
kVarlen_
,
typename
input_t_
,
typename
weight_t_
>
struct
Selective_Scan_fwd_kernel_traits
{
static_assert
(
kNItems_
%
4
==
0
);
using
input_t
=
input_t_
;
using
weight_t
=
weight_t_
;
static
constexpr
int
kNThreads
=
kNThreads_
;
// Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy.
static
constexpr
int
kMinBlocks
=
kNThreads
<
128
?
5
:
3
;
static
constexpr
int
kNItems
=
kNItems_
;
static
constexpr
int
kNRows
=
kNRows_
;
static
constexpr
int
kNBytes
=
sizeof
(
input_t
);
static_assert
(
kNBytes
==
2
||
kNBytes
==
4
);
static
constexpr
int
kNElts
=
kNBytes
==
4
?
4
:
constexpr_min
(
8
,
kNItems
);
static_assert
(
kNItems
%
kNElts
==
0
);
static
constexpr
int
kNLoads
=
kNItems
/
kNElts
;
static
constexpr
bool
kIsEvenLen
=
kVarlen_
?
false
:
kIsEvenLen_
;
static
constexpr
bool
kIsVariableB
=
kIsVariableB_
;
static
constexpr
bool
kIsVariableC
=
kIsVariableC_
;
static
constexpr
bool
kHasZ
=
kHasZ_
;
static
constexpr
bool
kVarlen
=
kVarlen_
;
static
constexpr
bool
kDirectIO
=
kVarlen_
?
false
:
kIsEvenLen
&&
kNLoads
==
1
;
static
constexpr
int
kNLoadsIndex
=
kNItems
/
4
;
using
vec_t
=
typename
BytesToType
<
kNBytes
*
kNElts
>::
Type
;
using
scan_t
=
float2
;
using
BlockLoadT
=
cub
::
BlockLoad
<
input_t
,
kNThreads
,
kNItems
,
cub
::
BLOCK_LOAD_WARP_TRANSPOSE
>
;
using
BlockLoadVecT
=
cub
::
BlockLoad
<
vec_t
,
kNThreads
,
kNLoads
,
!
kDirectIO
?
cub
::
BLOCK_LOAD_WARP_TRANSPOSE
:
cub
::
BLOCK_LOAD_DIRECT
>
;
using
BlockLoadWeightT
=
cub
::
BlockLoad
<
input_t
,
kNThreads
,
kNItems
,
cub
::
BLOCK_LOAD_WARP_TRANSPOSE
>
;
using
BlockLoadWeightVecT
=
cub
::
BlockLoad
<
vec_t
,
kNThreads
,
kNLoads
,
!
kDirectIO
?
cub
::
BLOCK_LOAD_WARP_TRANSPOSE
:
cub
::
BLOCK_LOAD_DIRECT
>
;
using
BlockStoreT
=
cub
::
BlockStore
<
input_t
,
kNThreads
,
kNItems
,
cub
::
BLOCK_STORE_WARP_TRANSPOSE
>
;
using
BlockStoreVecT
=
cub
::
BlockStore
<
vec_t
,
kNThreads
,
kNLoads
,
!
kDirectIO
?
cub
::
BLOCK_STORE_WARP_TRANSPOSE
:
cub
::
BLOCK_STORE_DIRECT
>
;
// using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
// using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
using
BlockScanT
=
cub
::
BlockScan
<
scan_t
,
kNThreads
,
cub
::
BLOCK_SCAN_WARP_SCANS
>
;
static
constexpr
int
kSmemIOSize
=
custom_max
({
sizeof
(
typename
BlockLoadT
::
TempStorage
),
sizeof
(
typename
BlockLoadVecT
::
TempStorage
),
(
int
(
kIsVariableB
)
+
int
(
kIsVariableC
))
*
sizeof
(
typename
BlockLoadWeightT
::
TempStorage
),
(
int
(
kIsVariableB
)
+
int
(
kIsVariableC
))
*
sizeof
(
typename
BlockLoadWeightVecT
::
TempStorage
),
sizeof
(
typename
BlockStoreT
::
TempStorage
),
sizeof
(
typename
BlockStoreVecT
::
TempStorage
)});
static
constexpr
int
kSmemSize
=
kSmemIOSize
+
sizeof
(
typename
BlockScanT
::
TempStorage
);
};
template
<
typename
Ktraits
>
__global__
__launch_bounds__
(
Ktraits
::
kNThreads
,
Ktraits
::
kMinBlocks
)
void
selective_scan_fwd_kernel
(
SSMParamsBase
params
)
{
constexpr
bool
kIsVariableB
=
Ktraits
::
kIsVariableB
;
constexpr
bool
kIsVariableC
=
Ktraits
::
kIsVariableC
;
constexpr
bool
kHasZ
=
Ktraits
::
kHasZ
;
constexpr
bool
kVarlen
=
Ktraits
::
kVarlen
;
constexpr
int
kNThreads
=
Ktraits
::
kNThreads
;
constexpr
int
kNItems
=
Ktraits
::
kNItems
;
constexpr
int
kNRows
=
Ktraits
::
kNRows
;
constexpr
bool
kDirectIO
=
Ktraits
::
kDirectIO
;
using
input_t
=
typename
Ktraits
::
input_t
;
using
weight_t
=
typename
Ktraits
::
weight_t
;
using
scan_t
=
typename
Ktraits
::
scan_t
;
// Shared memory.
extern
__shared__
char
smem_
[];
// cast to lvalue reference of expected type
// char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t);
// auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_ + 2 * MAX_DSTATE * sizeof(weight_t));
// auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan);
auto
&
smem_load
=
reinterpret_cast
<
typename
Ktraits
::
BlockLoadT
::
TempStorage
&>
(
smem_
);
auto
&
smem_load_weight
=
reinterpret_cast
<
typename
Ktraits
::
BlockLoadWeightT
::
TempStorage
&>
(
smem_
);
auto
&
smem_load_weight1
=
*
reinterpret_cast
<
typename
Ktraits
::
BlockLoadWeightT
::
TempStorage
*>
(
smem_
+
sizeof
(
typename
Ktraits
::
BlockLoadWeightT
::
TempStorage
));
auto
&
smem_store
=
reinterpret_cast
<
typename
Ktraits
::
BlockStoreT
::
TempStorage
&>
(
smem_
);
auto
&
smem_scan
=
*
reinterpret_cast
<
typename
Ktraits
::
BlockScanT
::
TempStorage
*>
(
smem_
+
Ktraits
::
kSmemIOSize
);
// weight_t *smem_a = reinterpret_cast<weight_t *>(smem_ + smem_loadstorescan_size);
// weight_t *smem_bc = reinterpret_cast<weight_t *>(smem_a + MAX_DSTATE);
scan_t
*
smem_running_prefix
=
reinterpret_cast
<
scan_t
*>
(
smem_
+
Ktraits
::
kSmemSize
);
const
int
batch_id
=
blockIdx
.
x
;
const
int
dim_id
=
blockIdx
.
y
;
const
int
group_id
=
dim_id
/
(
params
.
dim_ngroups_ratio
);
int
seqlen
=
params
.
seqlen
;
int
sequence_start_index
=
batch_id
;
if
constexpr
(
kVarlen
){
int
*
query_start_loc
=
reinterpret_cast
<
int
*>
(
params
.
query_start_loc_ptr
);
sequence_start_index
=
query_start_loc
[
batch_id
];
seqlen
=
query_start_loc
[
batch_id
+
1
]
-
sequence_start_index
;
}
const
bool
has_initial_state
=
params
.
has_initial_state_ptr
==
nullptr
?
false
:
reinterpret_cast
<
bool
*>
(
params
.
has_initial_state_ptr
)[
batch_id
];
const
int
*
cache_indices
=
params
.
cache_indices_ptr
==
nullptr
?
nullptr
:
reinterpret_cast
<
int
*>
(
params
.
cache_indices_ptr
);
const
int
cache_index
=
cache_indices
==
nullptr
?
batch_id
:
cache_indices
[
batch_id
];
// cache_index == params.pad_slot_id is defined as padding, so we exit early
if
(
cache_index
==
params
.
pad_slot_id
){
return
;
}
input_t
*
u
=
reinterpret_cast
<
input_t
*>
(
params
.
u_ptr
)
+
sequence_start_index
*
params
.
u_batch_stride
+
dim_id
*
kNRows
*
params
.
u_d_stride
;
input_t
*
delta
=
reinterpret_cast
<
input_t
*>
(
params
.
delta_ptr
)
+
sequence_start_index
*
params
.
delta_batch_stride
+
dim_id
*
kNRows
*
params
.
delta_d_stride
;
weight_t
*
A
=
reinterpret_cast
<
weight_t
*>
(
params
.
A_ptr
)
+
dim_id
*
kNRows
*
params
.
A_d_stride
;
weight_t
*
B
=
reinterpret_cast
<
weight_t
*>
(
params
.
B_ptr
)
+
dim_id
*
kNRows
*
params
.
B_d_stride
;
input_t
*
Bvar
=
reinterpret_cast
<
input_t
*>
(
params
.
B_ptr
)
+
sequence_start_index
*
params
.
B_batch_stride
+
group_id
*
params
.
B_group_stride
;
weight_t
*
C
=
reinterpret_cast
<
weight_t
*>
(
params
.
C_ptr
)
+
dim_id
*
kNRows
*
params
.
C_d_stride
;
input_t
*
Cvar
=
reinterpret_cast
<
input_t
*>
(
params
.
C_ptr
)
+
sequence_start_index
*
params
.
C_batch_stride
+
group_id
*
params
.
C_group_stride
;
input_t
*
ssm_states
=
reinterpret_cast
<
input_t
*>
(
params
.
ssm_states_ptr
)
+
(
cache_index
*
params
.
dim
+
dim_id
*
kNRows
)
*
params
.
dstate
;
float
D_val
[
kNRows
]
=
{
0
};
if
(
params
.
D_ptr
!=
nullptr
)
{
#pragma unroll
for
(
int
r
=
0
;
r
<
kNRows
;
++
r
)
{
D_val
[
r
]
=
reinterpret_cast
<
float
*>
(
params
.
D_ptr
)[
dim_id
*
kNRows
+
r
];
}
}
float
delta_bias
[
kNRows
]
=
{
0
};
if
(
params
.
delta_bias_ptr
!=
nullptr
)
{
#pragma unroll
for
(
int
r
=
0
;
r
<
kNRows
;
++
r
)
{
delta_bias
[
r
]
=
reinterpret_cast
<
float
*>
(
params
.
delta_bias_ptr
)[
dim_id
*
kNRows
+
r
];
}
}
// for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
// smem_a[state_idx] = A[state_idx * params.A_dstate_stride];
// smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride];
// }
constexpr
int
kChunkSize
=
kNThreads
*
kNItems
;
const
int
n_chunks
=
(
seqlen
+
2048
-
1
)
/
2048
;
for
(
int
chunk
=
0
;
chunk
<
n_chunks
;
++
chunk
)
{
input_t
u_vals
[
kNRows
][
kNItems
],
delta_vals_load
[
kNRows
][
kNItems
];
__syncthreads
();
#pragma unroll
for
(
int
r
=
0
;
r
<
kNRows
;
++
r
)
{
if
constexpr
(
!
kDirectIO
)
{
if
(
r
>
0
)
{
__syncthreads
();
}
}
load_input
<
Ktraits
>
(
u
+
r
*
params
.
u_d_stride
,
u_vals
[
r
],
smem_load
,
seqlen
-
chunk
*
kChunkSize
);
if
constexpr
(
!
kDirectIO
)
{
__syncthreads
();
}
load_input
<
Ktraits
>
(
delta
+
r
*
params
.
delta_d_stride
,
delta_vals_load
[
r
],
smem_load
,
seqlen
-
chunk
*
kChunkSize
);
}
u
+=
kChunkSize
;
delta
+=
kChunkSize
;
float
delta_vals
[
kNRows
][
kNItems
],
delta_u_vals
[
kNRows
][
kNItems
],
out_vals
[
kNRows
][
kNItems
];
#pragma unroll
for
(
int
r
=
0
;
r
<
kNRows
;
++
r
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
kNItems
;
++
i
)
{
float
u_val
=
float
(
u_vals
[
r
][
i
]);
delta_vals
[
r
][
i
]
=
float
(
delta_vals_load
[
r
][
i
])
+
delta_bias
[
r
];
if
(
params
.
delta_softplus
)
{
delta_vals
[
r
][
i
]
=
delta_vals
[
r
][
i
]
<=
20.
f
?
log1pf
(
expf
(
delta_vals
[
r
][
i
]))
:
delta_vals
[
r
][
i
];
}
delta_u_vals
[
r
][
i
]
=
delta_vals
[
r
][
i
]
*
u_val
;
out_vals
[
r
][
i
]
=
D_val
[
r
]
*
u_val
;
}
}
__syncthreads
();
for
(
int
state_idx
=
0
;
state_idx
<
params
.
dstate
;
++
state_idx
)
{
weight_t
A_val
[
kNRows
];
#pragma unroll
for
(
int
r
=
0
;
r
<
kNRows
;
++
r
)
{
A_val
[
r
]
=
A
[
state_idx
*
params
.
A_dstate_stride
+
r
*
params
.
A_d_stride
];
// Multiply the real part of A with LOG2E so we can use exp2f instead of expf.
constexpr
float
kLog2e
=
M_LOG2E
;
A_val
[
r
]
*=
kLog2e
;
}
// This variable holds B * C if both B and C are constant across seqlen. If only B varies
// across seqlen, this holds C. If only C varies across seqlen, this holds B.
// If both B and C vary, this is unused.
weight_t
BC_val
[
kNRows
];
weight_t
B_vals
[
kNItems
],
C_vals
[
kNItems
];
if
constexpr
(
kIsVariableB
)
{
load_weight
<
Ktraits
>
(
Bvar
+
state_idx
*
params
.
B_dstate_stride
,
B_vals
,
smem_load_weight
,
(
seqlen
-
chunk
*
kChunkSize
)
*
(
1
));
if
constexpr
(
!
kIsVariableC
)
{
#pragma unroll
for
(
int
r
=
0
;
r
<
kNRows
;
++
r
)
{
BC_val
[
r
]
=
C
[
state_idx
*
params
.
C_dstate_stride
+
r
*
params
.
C_d_stride
];
}
}
}
if
constexpr
(
kIsVariableC
)
{
auto
&
smem_load_weight_C
=
!
kIsVariableB
?
smem_load_weight
:
smem_load_weight1
;
load_weight
<
Ktraits
>
(
Cvar
+
state_idx
*
params
.
C_dstate_stride
,
C_vals
,
smem_load_weight_C
,
(
seqlen
-
chunk
*
kChunkSize
)
*
(
1
));
if
constexpr
(
!
kIsVariableB
)
{
#pragma unroll
for
(
int
r
=
0
;
r
<
kNRows
;
++
r
)
{
BC_val
[
r
]
=
B
[
state_idx
*
params
.
B_dstate_stride
+
r
*
params
.
B_d_stride
];
}
}
}
if
constexpr
(
!
kIsVariableB
&&
!
kIsVariableC
)
{
#pragma unroll
for
(
int
r
=
0
;
r
<
kNRows
;
++
r
)
{
BC_val
[
r
]
=
B
[
state_idx
*
params
.
B_dstate_stride
+
r
*
params
.
B_d_stride
]
*
C
[
state_idx
*
params
.
C_dstate_stride
+
r
*
params
.
C_d_stride
];
}
}
#pragma unroll
for
(
int
r
=
0
;
r
<
kNRows
;
++
r
)
{
if
(
r
>
0
)
{
__syncthreads
();
}
// Scan could be using the same smem
scan_t
thread_data
[
kNItems
];
#pragma unroll
for
(
int
i
=
0
;
i
<
kNItems
;
++
i
)
{
thread_data
[
i
]
=
make_float2
(
exp2f
(
delta_vals
[
r
][
i
]
*
A_val
[
r
]),
!
kIsVariableB
?
delta_u_vals
[
r
][
i
]
:
B_vals
[
i
]
*
delta_u_vals
[
r
][
i
]);
if
(
seqlen
%
(
kNItems
*
kNThreads
)
!=
0
)
{
// So that the last state is correct
if
(
threadIdx
.
x
*
kNItems
+
i
>=
seqlen
-
chunk
*
kChunkSize
)
{
thread_data
[
i
]
=
make_float2
(
1.
f
,
0.
f
);
}
}
}
// Initialize running total
scan_t
running_prefix
=
chunk
>
0
?
smem_running_prefix
[
state_idx
+
r
*
MAX_DSTATE
]
:
make_float2
(
1.0
,
has_initial_state
?
float
(
ssm_states
[
state_idx
])
:
0.0
);
SSMScanPrefixCallbackOp
<
weight_t
>
prefix_op
(
running_prefix
);
typename
Ktraits
::
BlockScanT
(
smem_scan
).
InclusiveScan
(
thread_data
,
thread_data
,
SSMScanOp
<
weight_t
>
(),
prefix_op
);
// There's a syncthreads in the scan op, so we don't need to sync here.
// Unless there's only 1 warp, but then it's the same thread (0) reading and writing.
if
(
threadIdx
.
x
==
0
)
{
smem_running_prefix
[
state_idx
]
=
prefix_op
.
running_prefix
;
if
(
chunk
==
n_chunks
-
1
)
{
ssm_states
[
state_idx
]
=
input_t
(
prefix_op
.
running_prefix
.
y
);
}
}
#pragma unroll
for
(
int
i
=
0
;
i
<
kNItems
;
++
i
)
{
const
weight_t
C_val
=
!
kIsVariableC
?
BC_val
[
r
]
:
(
!
kIsVariableB
?
BC_val
[
r
]
*
C_vals
[
i
]
:
C_vals
[
i
]);
out_vals
[
r
][
i
]
+=
thread_data
[
i
].
y
*
C_val
;
}
}
}
input_t
*
out
=
reinterpret_cast
<
input_t
*>
(
params
.
out_ptr
)
+
sequence_start_index
*
params
.
out_batch_stride
+
dim_id
*
kNRows
*
params
.
out_d_stride
+
chunk
*
kChunkSize
;
__syncthreads
();
#pragma unroll
for
(
int
r
=
0
;
r
<
kNRows
;
++
r
)
{
if
constexpr
(
!
kDirectIO
)
{
if
(
r
>
0
)
{
__syncthreads
();
}
}
store_output
<
Ktraits
>
(
out
+
r
*
params
.
out_d_stride
,
out_vals
[
r
],
smem_store
,
seqlen
-
chunk
*
kChunkSize
);
}
if
constexpr
(
kHasZ
)
{
input_t
*
z
=
reinterpret_cast
<
input_t
*>
(
params
.
z_ptr
)
+
sequence_start_index
*
params
.
z_batch_stride
+
dim_id
*
kNRows
*
params
.
z_d_stride
+
chunk
*
kChunkSize
;
input_t
*
out_z
=
reinterpret_cast
<
input_t
*>
(
params
.
out_z_ptr
)
+
sequence_start_index
*
params
.
out_z_batch_stride
+
dim_id
*
kNRows
*
params
.
out_z_d_stride
+
chunk
*
kChunkSize
;
#pragma unroll
for
(
int
r
=
0
;
r
<
kNRows
;
++
r
)
{
input_t
z_vals
[
kNItems
];
__syncthreads
();
load_input
<
Ktraits
>
(
z
+
r
*
params
.
z_d_stride
,
z_vals
,
smem_load
,
seqlen
-
chunk
*
kChunkSize
);
#pragma unroll
for
(
int
i
=
0
;
i
<
kNItems
;
++
i
)
{
float
z_val
=
z_vals
[
i
];
out_vals
[
r
][
i
]
*=
z_val
/
(
1
+
expf
(
-
z_val
));
}
__syncthreads
();
store_output
<
Ktraits
>
(
out_z
+
r
*
params
.
out_z_d_stride
,
out_vals
[
r
],
smem_store
,
seqlen
-
chunk
*
kChunkSize
);
}
}
Bvar
+=
kChunkSize
*
1
;
Cvar
+=
kChunkSize
*
1
;
}
}
template
<
int
kNThreads
,
int
kNItems
,
typename
input_t
,
typename
weight_t
>
void
selective_scan_fwd_launch
(
SSMParamsBase
&
params
,
cudaStream_t
stream
)
{
// Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block
// processing 1 row.
constexpr
int
kNRows
=
1
;
// kIsVariableB, kIsVariableC and kHasZ are all set to True to reduce binary size
constexpr
bool
kIsVariableB
=
true
;
constexpr
bool
kIsVariableC
=
true
;
constexpr
bool
kHasZ
=
true
;
BOOL_SWITCH
(
params
.
seqlen
%
(
kNThreads
*
kNItems
)
==
0
,
kIsEvenLen
,
[
&
]
{
BOOL_SWITCH
(
params
.
query_start_loc_ptr
!=
nullptr
,
kVarlen
,
[
&
]
{
using
Ktraits
=
Selective_Scan_fwd_kernel_traits
<
kNThreads
,
kNItems
,
kNRows
,
kIsEvenLen
,
kIsVariableB
,
kIsVariableC
,
kHasZ
,
kVarlen
,
input_t
,
weight_t
>
;
constexpr
int
kSmemSize
=
Ktraits
::
kSmemSize
+
kNRows
*
MAX_DSTATE
*
sizeof
(
typename
Ktraits
::
scan_t
);
dim3
grid
(
params
.
batch
,
params
.
dim
/
kNRows
);
auto
kernel
=
&
selective_scan_fwd_kernel
<
Ktraits
>
;
if
(
kSmemSize
>=
48
*
1024
)
{
C10_CUDA_CHECK
(
cudaFuncSetAttribute
(
kernel
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
kSmemSize
));
}
kernel
<<<
grid
,
Ktraits
::
kNThreads
,
kSmemSize
,
stream
>>>
(
params
);
C10_CUDA_KERNEL_LAUNCH_CHECK
();
});
});
}
template
<
typename
input_t
,
typename
weight_t
>
void
selective_scan_fwd_cuda
(
SSMParamsBase
&
params
,
cudaStream_t
stream
)
{
#ifndef USE_ROCM
if
(
params
.
seqlen
<=
128
)
{
selective_scan_fwd_launch
<
32
,
4
,
input_t
,
weight_t
>
(
params
,
stream
);
}
else
if
(
params
.
seqlen
<=
256
)
{
selective_scan_fwd_launch
<
32
,
8
,
input_t
,
weight_t
>
(
params
,
stream
);
}
else
if
(
params
.
seqlen
<=
512
)
{
selective_scan_fwd_launch
<
32
,
16
,
input_t
,
weight_t
>
(
params
,
stream
);
}
else
if
(
params
.
seqlen
<=
1024
)
{
selective_scan_fwd_launch
<
64
,
16
,
input_t
,
weight_t
>
(
params
,
stream
);
}
else
{
selective_scan_fwd_launch
<
128
,
16
,
input_t
,
weight_t
>
(
params
,
stream
);
}
#else
if
(
params
.
seqlen
<=
256
)
{
selective_scan_fwd_launch
<
64
,
4
,
input_t
,
weight_t
>
(
params
,
stream
);
}
else
if
(
params
.
seqlen
<=
512
)
{
selective_scan_fwd_launch
<
64
,
8
,
input_t
,
weight_t
>
(
params
,
stream
);
}
else
if
(
params
.
seqlen
<=
1024
)
{
selective_scan_fwd_launch
<
64
,
16
,
input_t
,
weight_t
>
(
params
,
stream
);
}
else
{
selective_scan_fwd_launch
<
128
,
16
,
input_t
,
weight_t
>
(
params
,
stream
);
}
#endif
}
template
void
selective_scan_fwd_cuda
<
at
::
BFloat16
,
float
>(
SSMParamsBase
&
params
,
cudaStream_t
stream
);
template
void
selective_scan_fwd_cuda
<
at
::
Half
,
float
>(
SSMParamsBase
&
params
,
cudaStream_t
stream
);
template
void
selective_scan_fwd_cuda
<
float
,
float
>(
SSMParamsBase
&
params
,
cudaStream_t
stream
);
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
if (ITYPE == at::ScalarType::Half) { \
using input_t = at::Half; \
using weight_t = float; \
__VA_ARGS__(); \
} else if (ITYPE == at::ScalarType::BFloat16) { \
using input_t = at::BFloat16; \
using weight_t = float; \
__VA_ARGS__(); \
} else if (ITYPE == at::ScalarType::Float) { \
using input_t = float; \
using weight_t = float; \
__VA_ARGS__(); \
} else { \
AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
}
template
<
typename
input_t
,
typename
weight_t
>
void
selective_scan_fwd_cuda
(
SSMParamsBase
&
params
,
cudaStream_t
stream
);
void
set_ssm_params_fwd
(
SSMParamsBase
&
params
,
// sizes
const
size_t
batch
,
const
size_t
dim
,
const
size_t
seqlen
,
const
size_t
dstate
,
const
size_t
n_groups
,
const
bool
is_variable_B
,
const
bool
is_variable_C
,
// device pointers
const
torch
::
Tensor
u
,
const
torch
::
Tensor
delta
,
const
torch
::
Tensor
A
,
const
torch
::
Tensor
B
,
const
torch
::
Tensor
C
,
const
torch
::
Tensor
out
,
const
torch
::
Tensor
z
,
const
torch
::
Tensor
out_z
,
const
c10
::
optional
<
at
::
Tensor
>&
D
,
const
c10
::
optional
<
at
::
Tensor
>&
delta_bias
,
const
torch
::
Tensor
ssm_states
,
bool
has_z
,
bool
delta_softplus
,
const
c10
::
optional
<
at
::
Tensor
>&
query_start_loc
,
const
c10
::
optional
<
at
::
Tensor
>&
cache_indices
,
const
c10
::
optional
<
at
::
Tensor
>&
has_initial_state
,
bool
varlen
,
int64_t
pad_slot_id
)
{
// Reset the parameters
memset
(
&
params
,
0
,
sizeof
(
params
));
params
.
batch
=
batch
;
params
.
dim
=
dim
;
params
.
seqlen
=
seqlen
;
params
.
dstate
=
dstate
;
params
.
n_groups
=
n_groups
;
params
.
dim_ngroups_ratio
=
dim
/
n_groups
;
params
.
pad_slot_id
=
pad_slot_id
;
params
.
delta_softplus
=
delta_softplus
;
params
.
is_variable_B
=
is_variable_B
;
params
.
is_variable_C
=
is_variable_C
;
// Set the pointers and strides.
params
.
u_ptr
=
u
.
data_ptr
();
params
.
delta_ptr
=
delta
.
data_ptr
();
params
.
A_ptr
=
A
.
data_ptr
();
params
.
B_ptr
=
B
.
data_ptr
();
params
.
C_ptr
=
C
.
data_ptr
();
params
.
D_ptr
=
D
.
has_value
()
?
D
.
value
().
data_ptr
()
:
nullptr
;
params
.
delta_bias_ptr
=
delta_bias
.
has_value
()
?
delta_bias
.
value
().
data_ptr
()
:
nullptr
;
params
.
out_ptr
=
out
.
data_ptr
();
params
.
ssm_states_ptr
=
ssm_states
.
data_ptr
();
params
.
z_ptr
=
has_z
?
z
.
data_ptr
()
:
nullptr
;
params
.
out_z_ptr
=
has_z
?
out_z
.
data_ptr
()
:
nullptr
;
params
.
query_start_loc_ptr
=
query_start_loc
.
has_value
()
?
query_start_loc
.
value
().
data_ptr
()
:
nullptr
;
params
.
cache_indices_ptr
=
cache_indices
.
has_value
()
?
cache_indices
.
value
().
data_ptr
()
:
nullptr
;
params
.
has_initial_state_ptr
=
has_initial_state
.
has_value
()
?
has_initial_state
.
value
().
data_ptr
()
:
nullptr
;
// All stride are in elements, not bytes.
params
.
A_d_stride
=
A
.
stride
(
0
);
params
.
A_dstate_stride
=
A
.
stride
(
1
);
if
(
varlen
){
params
.
B_batch_stride
=
B
.
stride
(
2
);
params
.
B_group_stride
=
B
.
stride
(
0
);
params
.
B_dstate_stride
=
B
.
stride
(
1
);
params
.
C_batch_stride
=
C
.
stride
(
2
);
params
.
C_group_stride
=
C
.
stride
(
0
);
params
.
C_dstate_stride
=
C
.
stride
(
1
);
params
.
u_batch_stride
=
u
.
stride
(
1
);
params
.
u_d_stride
=
u
.
stride
(
0
);
params
.
delta_batch_stride
=
delta
.
stride
(
1
);
params
.
delta_d_stride
=
delta
.
stride
(
0
);
if
(
has_z
)
{
params
.
z_batch_stride
=
z
.
stride
(
1
);
params
.
z_d_stride
=
z
.
stride
(
0
);
params
.
out_z_batch_stride
=
out_z
.
stride
(
1
);
params
.
out_z_d_stride
=
out_z
.
stride
(
0
);
}
params
.
out_batch_stride
=
out
.
stride
(
1
);
params
.
out_d_stride
=
out
.
stride
(
0
);
}
else
{
if
(
!
is_variable_B
)
{
params
.
B_d_stride
=
B
.
stride
(
0
);
}
else
{
params
.
B_batch_stride
=
B
.
stride
(
0
);
params
.
B_group_stride
=
B
.
stride
(
1
);
}
params
.
B_dstate_stride
=
!
is_variable_B
?
B
.
stride
(
1
)
:
B
.
stride
(
2
);
if
(
!
is_variable_C
)
{
params
.
C_d_stride
=
C
.
stride
(
0
);
}
else
{
params
.
C_batch_stride
=
C
.
stride
(
0
);
params
.
C_group_stride
=
C
.
stride
(
1
);
}
params
.
C_dstate_stride
=
!
is_variable_C
?
C
.
stride
(
1
)
:
C
.
stride
(
2
);
params
.
u_batch_stride
=
u
.
stride
(
0
);
params
.
u_d_stride
=
u
.
stride
(
1
);
params
.
delta_batch_stride
=
delta
.
stride
(
0
);
params
.
delta_d_stride
=
delta
.
stride
(
1
);
if
(
has_z
)
{
params
.
z_batch_stride
=
z
.
stride
(
0
);
params
.
z_d_stride
=
z
.
stride
(
1
);
params
.
out_z_batch_stride
=
out_z
.
stride
(
0
);
params
.
out_z_d_stride
=
out_z
.
stride
(
1
);
}
params
.
out_batch_stride
=
out
.
stride
(
0
);
params
.
out_d_stride
=
out
.
stride
(
1
);
}
}
void
selective_scan_fwd
(
const
torch
::
Tensor
&
u
,
const
torch
::
Tensor
&
delta
,
const
torch
::
Tensor
&
A
,
const
torch
::
Tensor
&
B
,
const
torch
::
Tensor
&
C
,
const
c10
::
optional
<
torch
::
Tensor
>
&
D_
,
const
c10
::
optional
<
torch
::
Tensor
>
&
z_
,
const
c10
::
optional
<
torch
::
Tensor
>
&
delta_bias_
,
bool
delta_softplus
,
const
c10
::
optional
<
torch
::
Tensor
>
&
query_start_loc
,
const
c10
::
optional
<
torch
::
Tensor
>
&
cache_indices
,
const
c10
::
optional
<
torch
::
Tensor
>
&
has_initial_state
,
const
torch
::
Tensor
&
ssm_states
,
// used to identify padding entries if cache_indices provided
// in case of padding, the kernel will return early
int64_t
pad_slot_id
)
{
auto
input_type
=
u
.
scalar_type
();
auto
weight_type
=
A
.
scalar_type
();
TORCH_CHECK
(
input_type
==
at
::
ScalarType
::
Float
||
input_type
==
at
::
ScalarType
::
Half
||
input_type
==
at
::
ScalarType
::
BFloat16
);
TORCH_CHECK
(
weight_type
==
at
::
ScalarType
::
Float
);
const
bool
is_variable_B
=
B
.
dim
()
>=
3
;
const
bool
is_variable_C
=
C
.
dim
()
>=
3
;
TORCH_CHECK
(
delta
.
scalar_type
()
==
input_type
);
TORCH_CHECK
(
B
.
scalar_type
()
==
(
!
is_variable_B
?
weight_type
:
input_type
));
TORCH_CHECK
(
C
.
scalar_type
()
==
(
!
is_variable_C
?
weight_type
:
input_type
));
TORCH_CHECK
(
u
.
is_cuda
());
TORCH_CHECK
(
delta
.
is_cuda
());
TORCH_CHECK
(
A
.
is_cuda
());
TORCH_CHECK
(
B
.
is_cuda
());
TORCH_CHECK
(
C
.
is_cuda
());
TORCH_CHECK
(
u
.
stride
(
-
1
)
==
1
||
u
.
size
(
-
1
)
==
1
);
TORCH_CHECK
(
delta
.
stride
(
-
1
)
==
1
||
delta
.
size
(
-
1
)
==
1
);
const
auto
sizes
=
u
.
sizes
();
const
bool
varlen
=
query_start_loc
.
has_value
();
const
int
batch_size
=
varlen
?
query_start_loc
.
value
().
sizes
()[
0
]
-
1
:
sizes
[
0
];
const
int
dim
=
varlen
?
sizes
[
0
]
:
sizes
[
1
];
const
int
seqlen
=
varlen
?
sizes
[
1
]
:
sizes
[
2
];
const
int
dstate
=
A
.
size
(
1
);
const
int
n_groups
=
varlen
?
B
.
size
(
0
)
:
B
.
size
(
1
);
TORCH_CHECK
(
dstate
<=
256
,
"selective_scan only supports state dimension <= 256"
);
if
(
varlen
)
{
CHECK_SHAPE
(
u
,
dim
,
seqlen
);
CHECK_SHAPE
(
delta
,
dim
,
seqlen
);
}
else
{
CHECK_SHAPE
(
u
,
batch_size
,
dim
,
seqlen
);
CHECK_SHAPE
(
delta
,
batch_size
,
dim
,
seqlen
);
}
CHECK_SHAPE
(
A
,
dim
,
dstate
);
TORCH_CHECK
(
is_variable_B
,
"is_variable_B = False is disabled in favor of reduced binary size"
)
if
(
varlen
)
{
CHECK_SHAPE
(
B
,
n_groups
,
dstate
,
seqlen
);
}
else
{
CHECK_SHAPE
(
B
,
batch_size
,
n_groups
,
dstate
,
seqlen
);
}
TORCH_CHECK
(
B
.
stride
(
-
1
)
==
1
||
B
.
size
(
-
1
)
==
1
);
TORCH_CHECK
(
is_variable_C
,
"is_variable_C = False is disabled in favor of reduced binary size"
)
if
(
varlen
)
{
CHECK_SHAPE
(
C
,
n_groups
,
dstate
,
seqlen
);
}
else
{
CHECK_SHAPE
(
C
,
batch_size
,
n_groups
,
dstate
,
seqlen
);
}
TORCH_CHECK
(
C
.
stride
(
-
1
)
==
1
||
C
.
size
(
-
1
)
==
1
);
if
(
D_
.
has_value
())
{
auto
D
=
D_
.
value
();
TORCH_CHECK
(
D
.
scalar_type
()
==
at
::
ScalarType
::
Float
);
TORCH_CHECK
(
D
.
is_cuda
());
TORCH_CHECK
(
D
.
stride
(
-
1
)
==
1
||
D
.
size
(
-
1
)
==
1
);
CHECK_SHAPE
(
D
,
dim
);
}
if
(
delta_bias_
.
has_value
())
{
auto
delta_bias
=
delta_bias_
.
value
();
TORCH_CHECK
(
delta_bias
.
scalar_type
()
==
at
::
ScalarType
::
Float
);
TORCH_CHECK
(
delta_bias
.
is_cuda
());
TORCH_CHECK
(
delta_bias
.
stride
(
-
1
)
==
1
||
delta_bias
.
size
(
-
1
)
==
1
);
CHECK_SHAPE
(
delta_bias
,
dim
);
}
if
(
has_initial_state
.
has_value
())
{
auto
has_initial_state_
=
has_initial_state
.
value
();
TORCH_CHECK
(
has_initial_state_
.
scalar_type
()
==
at
::
ScalarType
::
Bool
);
TORCH_CHECK
(
has_initial_state_
.
is_cuda
());
CHECK_SHAPE
(
has_initial_state_
,
batch_size
);
}
if
(
query_start_loc
.
has_value
())
{
auto
query_start_loc_
=
query_start_loc
.
value
();
TORCH_CHECK
(
query_start_loc_
.
scalar_type
()
==
at
::
ScalarType
::
Int
);
TORCH_CHECK
(
query_start_loc_
.
is_cuda
());
}
if
(
cache_indices
.
has_value
())
{
auto
cache_indices_
=
cache_indices
.
value
();
TORCH_CHECK
(
cache_indices_
.
scalar_type
()
==
at
::
ScalarType
::
Int
);
TORCH_CHECK
(
cache_indices_
.
is_cuda
());
CHECK_SHAPE
(
cache_indices_
,
batch_size
);
}
at
::
Tensor
z
,
out_z
;
const
bool
has_z
=
z_
.
has_value
();
TORCH_CHECK
(
has_z
,
"has_z = False is disabled in favor of reduced binary size"
)
z
=
z_
.
value
();
TORCH_CHECK
(
z
.
scalar_type
()
==
input_type
);
TORCH_CHECK
(
z
.
is_cuda
());
TORCH_CHECK
(
z
.
stride
(
-
1
)
==
1
||
z
.
size
(
-
1
)
==
1
);
if
(
varlen
){
CHECK_SHAPE
(
z
,
dim
,
seqlen
);
}
else
{
CHECK_SHAPE
(
z
,
batch_size
,
dim
,
seqlen
);
}
out_z
=
z
;
// Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
at
::
Tensor
out
=
delta
;
TORCH_CHECK
(
ssm_states
.
scalar_type
()
==
input_type
);
TORCH_CHECK
(
ssm_states
.
is_cuda
());
TORCH_CHECK
(
ssm_states
.
stride
(
-
1
)
==
1
);
SSMParamsBase
params
;
set_ssm_params_fwd
(
params
,
batch_size
,
dim
,
seqlen
,
dstate
,
n_groups
,
is_variable_B
,
is_variable_C
,
u
,
delta
,
A
,
B
,
C
,
out
,
z
,
out_z
,
D_
,
delta_bias_
,
ssm_states
,
has_z
,
delta_softplus
,
query_start_loc
,
cache_indices
,
has_initial_state
,
varlen
,
pad_slot_id
);
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at
::
cuda
::
CUDAGuard
device_guard
{(
char
)
u
.
get_device
()};
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16
(
u
.
scalar_type
(),
"selective_scan_fwd"
,
[
&
]
{
selective_scan_fwd_cuda
<
input_t
,
weight_t
>
(
params
,
stream
);
});
}
csrc/mamba/mamba_ssm/static_switch.h
0 → 100644
View file @
ad385667
// Inspired by
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
// clang-format off
// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/static_switch.h
#pragma once
/// @param COND - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ... - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
/// some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
csrc/moe/marlin_kernels/marlin_moe_kernel.h
0 → 100644
View file @
ad385667
#pragma once
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>
#include "core/scalar_type.hpp"
namespace
marlin_moe
{
constexpr
int
ceildiv
(
int
a
,
int
b
)
{
return
(
a
+
b
-
1
)
/
b
;
}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// Instances of `Vec` are used to organize groups of >>registers<<, as needed
// for instance as inputs to tensor core operations. Consequently, all
// corresponding index accesses must be compile-time constants, which is why we
// extensively use `#pragma unroll` throughout the kernel code to guarantee
// this.
template
<
typename
T
,
int
n
>
struct
Vec
{
T
elems
[
n
];
__device__
T
&
operator
[](
int
i
)
{
return
elems
[
i
];
}
};
using
I4
=
Vec
<
int
,
4
>
;
// Matrix fragments for tensor core instructions; their precise layout is
// documented here:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
using
FragA
=
Vec
<
half2
,
4
>
;
using
FragB
=
Vec
<
half2
,
2
>
;
using
FragC
=
Vec
<
float
,
4
>
;
using
FragS
=
Vec
<
half2
,
1
>
;
// quantization scales
using
FragZP
=
Vec
<
half2
,
4
>
;
// Predicated asynchronous global->shared copy; used for inputs A where we apply
// predication to handle batchsizes that are not multiples of 16.
__device__
inline
void
cp_async4_pred
(
void
*
smem_ptr
,
const
void
*
glob_ptr
,
bool
pred
=
true
)
{
const
int
BYTES
=
16
;
uint32_t
smem
=
static_cast
<
uint32_t
>
(
__cvta_generic_to_shared
(
smem_ptr
));
asm
volatile
(
"{
\n
"
" .reg .pred p;
\n
"
" setp.ne.b32 p, %0, 0;
\n
"
" @p cp.async.cg.shared.global [%1], [%2], %3;
\n
"
"}
\n
"
::
"r"
((
int
)
pred
),
"r"
(
smem
),
"l"
(
glob_ptr
),
"n"
(
BYTES
));
}
// Asynchronous global->shared copy
__device__
inline
void
cp_async4
(
void
*
smem_ptr
,
const
void
*
glob_ptr
)
{
const
int
BYTES
=
16
;
uint32_t
smem
=
static_cast
<
uint32_t
>
(
__cvta_generic_to_shared
(
smem_ptr
));
asm
volatile
(
"{
\n
"
" cp.async.cg.shared.global [%0], [%1], %2;
\n
"
"}
\n
"
::
"r"
(
smem
),
"l"
(
glob_ptr
),
"n"
(
BYTES
));
}
// Async copy fence.
__device__
inline
void
cp_async_fence
()
{
asm
volatile
(
"cp.async.commit_group;
\n
"
::
);
}
// Wait until at most `n` async copy stages are still pending.
template
<
int
n
>
__device__
inline
void
cp_async_wait
()
{
asm
volatile
(
"cp.async.wait_group %0;
\n
"
::
"n"
(
n
));
}
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
// output/accumulation.
__device__
inline
void
mma
(
const
FragA
&
a_frag
,
const
FragB
&
frag_b
,
FragC
&
frag_c
)
{
const
uint32_t
*
a
=
reinterpret_cast
<
const
uint32_t
*>
(
&
a_frag
);
const
uint32_t
*
b
=
reinterpret_cast
<
const
uint32_t
*>
(
&
frag_b
);
float
*
c
=
reinterpret_cast
<
float
*>
(
&
frag_c
);
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};
\n
"
:
"=f"
(
c
[
0
]),
"=f"
(
c
[
1
]),
"=f"
(
c
[
2
]),
"=f"
(
c
[
3
])
:
"r"
(
a
[
0
]),
"r"
(
a
[
1
]),
"r"
(
a
[
2
]),
"r"
(
a
[
3
]),
"r"
(
b
[
0
]),
"r"
(
b
[
1
]),
"f"
(
c
[
0
]),
"f"
(
c
[
1
]),
"f"
(
c
[
2
]),
"f"
(
c
[
3
]));
}
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
__device__
inline
void
ldsm4
(
FragA
&
frag_a
,
const
void
*
smem_ptr
)
{
uint32_t
*
a
=
reinterpret_cast
<
uint32_t
*>
(
&
frag_a
);
uint32_t
smem
=
static_cast
<
uint32_t
>
(
__cvta_generic_to_shared
(
smem_ptr
));
asm
volatile
(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];
\n
"
:
"=r"
(
a
[
0
]),
"=r"
(
a
[
1
]),
"=r"
(
a
[
2
]),
"=r"
(
a
[
3
])
:
"r"
(
smem
));
}
// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.
template
<
int
lut
>
__device__
inline
int
lop3
(
int
a
,
int
b
,
int
c
)
{
int
res
;
asm
volatile
(
"lop3.b32 %0, %1, %2, %3, %4;
\n
"
:
"=r"
(
res
)
:
"r"
(
a
),
"r"
(
b
),
"r"
(
c
),
"n"
(
lut
));
return
res
;
}
// Constructs destination register by taking bytes from 2 sources (based on
// mask)
template
<
int
start_byte
,
int
mask
>
__device__
inline
uint32_t
prmt
(
uint32_t
a
)
{
uint32_t
res
;
asm
volatile
(
"prmt.b32 %0, %1, %2, %3;
\n
"
:
"=r"
(
res
)
:
"r"
(
a
),
"n"
(
start_byte
),
"n"
(
mask
));
return
res
;
}
template
<
vllm
::
ScalarTypeId
w_type_id
>
__device__
inline
FragB
dequant
(
int
q
);
// Efficiently dequantize 4bit values packed in an int32 value into a full
// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
// with some small changes:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
template
<
>
__device__
inline
FragB
dequant
<
vllm
::
kU4B8
.
id
()
>
(
int
q
)
{
const
int
LO
=
0x000f000f
;
const
int
HI
=
0x00f000f0
;
const
int
EX
=
0x64006400
;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
LO
,
EX
);
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
HI
,
EX
);
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
const
int
SUB
=
0x64086408
;
const
int
MUL
=
0x2c002c00
;
const
int
ADD
=
0xd480d480
;
FragB
frag_b
;
frag_b
[
0
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
lo
),
*
reinterpret_cast
<
const
half2
*>
(
&
SUB
));
frag_b
[
1
]
=
__hfma2
(
*
reinterpret_cast
<
half2
*>
(
&
hi
),
*
reinterpret_cast
<
const
half2
*>
(
&
MUL
),
*
reinterpret_cast
<
const
half2
*>
(
&
ADD
));
return
frag_b
;
}
// Fast Int8ToFp16: Efficiently dequantize 8bit int values to fp16
// Reference:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
template
<
>
__device__
inline
FragB
dequant
<
vllm
::
kU8B128
.
id
()
>
(
int
q
)
{
static
constexpr
uint32_t
mask_for_elt_01
=
0x5250
;
static
constexpr
uint32_t
mask_for_elt_23
=
0x5351
;
static
constexpr
uint32_t
start_byte_for_fp16
=
0x64646464
;
uint32_t
lo
=
prmt
<
start_byte_for_fp16
,
mask_for_elt_01
>
(
q
);
uint32_t
hi
=
prmt
<
start_byte_for_fp16
,
mask_for_elt_23
>
(
q
);
static
constexpr
uint32_t
I8s_TO_F16s_MAGIC_NUM
=
0x64806480
;
FragB
frag_b
;
frag_b
[
0
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
lo
),
*
reinterpret_cast
<
const
half2
*>
(
&
I8s_TO_F16s_MAGIC_NUM
));
frag_b
[
1
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
hi
),
*
reinterpret_cast
<
const
half2
*>
(
&
I8s_TO_F16s_MAGIC_NUM
));
return
frag_b
;
}
template
<
>
__device__
inline
FragB
dequant
<
vllm
::
kU4
.
id
()
>
(
int
q
)
{
const
int
LO
=
0x000f000f
;
const
int
HI
=
0x00f000f0
;
const
int
EX
=
0x64006400
;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
LO
,
EX
);
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
HI
,
EX
);
const
int
SUB
=
0x64006400
;
const
int
MUL
=
0x2c002c00
;
const
int
ADD
=
0xd400d400
;
FragB
frag_b
;
frag_b
[
0
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
lo
),
*
reinterpret_cast
<
const
half2
*>
(
&
SUB
));
frag_b
[
1
]
=
__hfma2
(
*
reinterpret_cast
<
half2
*>
(
&
hi
),
*
reinterpret_cast
<
const
half2
*>
(
&
MUL
),
*
reinterpret_cast
<
const
half2
*>
(
&
ADD
));
return
frag_b
;
}
template
<
>
__device__
inline
FragB
dequant
<
vllm
::
kU8
.
id
()
>
(
int
q
)
{
static
constexpr
uint32_t
mask_for_elt_01
=
0x5250
;
static
constexpr
uint32_t
mask_for_elt_23
=
0x5351
;
static
constexpr
uint32_t
start_byte_for_fp16
=
0x64646464
;
uint32_t
lo
=
prmt
<
start_byte_for_fp16
,
mask_for_elt_01
>
(
q
);
uint32_t
hi
=
prmt
<
start_byte_for_fp16
,
mask_for_elt_23
>
(
q
);
static
constexpr
uint32_t
I8s_TO_F16s_MAGIC_NUM
=
0x64006400
;
FragB
frag_b
;
frag_b
[
0
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
lo
),
*
reinterpret_cast
<
const
half2
*>
(
&
I8s_TO_F16s_MAGIC_NUM
));
frag_b
[
1
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
hi
),
*
reinterpret_cast
<
const
half2
*>
(
&
I8s_TO_F16s_MAGIC_NUM
));
return
frag_b
;
}
// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
__device__
inline
void
scale
(
FragB
&
frag_b
,
FragS
&
frag_s
,
int
i
)
{
half2
s
=
__half2half2
(
reinterpret_cast
<
__half
*>
(
&
frag_s
)[
i
]);
frag_b
[
0
]
=
__hmul2
(
frag_b
[
0
],
s
);
frag_b
[
1
]
=
__hmul2
(
frag_b
[
1
],
s
);
}
__device__
inline
void
sub_zp
(
FragB
&
frag_b
,
half2
&
frag_zp
,
int
i
)
{
half2
zp
=
__half2half2
(
reinterpret_cast
<
__half
*>
(
&
frag_zp
)[
i
]);
frag_b
[
0
]
=
__hsub2
(
frag_b
[
0
],
zp
);
frag_b
[
1
]
=
__hsub2
(
frag_b
[
1
],
zp
);
}
// Same as above, but for act_order (each K is multiplied individually)
__device__
inline
void
scale4
(
FragB
&
frag_b
,
FragS
&
frag_s_1
,
FragS
&
frag_s_2
,
FragS
&
frag_s_3
,
FragS
&
frag_s_4
,
int
i
)
{
__half2
s_val_1_2
;
s_val_1_2
.
x
=
reinterpret_cast
<
__half
*>
(
&
frag_s_1
)[
i
];
s_val_1_2
.
y
=
reinterpret_cast
<
__half
*>
(
&
frag_s_2
)[
i
];
__half2
s_val_3_4
;
s_val_3_4
.
x
=
reinterpret_cast
<
__half
*>
(
&
frag_s_3
)[
i
];
s_val_3_4
.
y
=
reinterpret_cast
<
__half
*>
(
&
frag_s_4
)[
i
];
frag_b
[
0
]
=
__hmul2
(
frag_b
[
0
],
s_val_1_2
);
frag_b
[
1
]
=
__hmul2
(
frag_b
[
1
],
s_val_3_4
);
}
// Given 2 floats multiply by 2 scales (halves)
__device__
inline
void
scale_float
(
float
*
c
,
FragS
&
s
)
{
__half
*
s_ptr
=
reinterpret_cast
<
__half
*>
(
&
s
);
c
[
0
]
=
__fmul_rn
(
c
[
0
],
__half2float
(
s_ptr
[
0
]));
c
[
1
]
=
__fmul_rn
(
c
[
1
],
__half2float
(
s_ptr
[
1
]));
}
// Wait until barrier reaches `count`, then lock for current threadblock.
__device__
inline
void
barrier_acquire
(
int
*
lock
,
int
count
)
{
if
(
threadIdx
.
x
==
0
)
{
int
state
=
-
1
;
do
// Guarantee that subsequent writes by this threadblock will be visible
// globally.
asm
volatile
(
"ld.global.acquire.gpu.b32 %0, [%1];
\n
"
:
"=r"
(
state
)
:
"l"
(
lock
));
while
(
state
!=
count
);
}
__syncthreads
();
}
// Release barrier and increment visitation count.
__device__
inline
void
barrier_release
(
int
*
lock
,
bool
reset
=
false
)
{
__syncthreads
();
if
(
threadIdx
.
x
==
0
)
{
if
(
reset
)
{
lock
[
0
]
=
0
;
return
;
}
int
val
=
1
;
// Make sure that all writes since acquiring this barrier are visible
// globally, while releasing the barrier.
asm
volatile
(
"fence.acq_rel.gpu;
\n
"
);
asm
volatile
(
"red.relaxed.gpu.global.add.s32 [%0], %1;
\n
"
:
:
"l"
(
lock
),
"r"
(
val
));
}
}
template
<
const
vllm
::
ScalarTypeId
w_type_id
,
// weight ScalarType id
const
int
threads
,
// number of threads in a threadblock
const
int
thread_m_blocks
,
// number of 16x16 blocks in the m
// dimension (batchsize) of the
// threadblock
const
int
thread_n_blocks
,
// same for n dimension (output)
const
int
thread_k_blocks
,
// same for k dimension (reduction)
const
int
stages
,
// number of stages for the async global->shared
// fetch pipeline
const
bool
has_act_order
,
// whether act_order is enabled
const
bool
has_zp
,
// whether zero-points are enabled
const
int
group_blocks
=
-
1
// number of consecutive 16x16 blocks
// with a separate quantization scale
>
__device__
void
MarlinMoESingle
(
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
const
int4
*
__restrict__
B
,
// 4bit quantized weight matrix of shape kxn
int4
*
__restrict__
C
,
// fp16 output buffer of shape mxn
const
int
*
__restrict__
sorted_ids
,
// int32 sorted ids of experts
const
float
*
__restrict__
topk_weights
,
// float topk weights
const
int4
*
__restrict__
scales_ptr
,
// fp16 quantization scales of shape
// (k/groupsize)xn
const
int4
*
__restrict__
zp_ptr
,
// 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const
int
*
__restrict__
g_idx
,
// int32 group indices of shape k
const
int
*
__restrict__
expert_offsets
,
int
num_groups
,
// number of scale groups per output channel
int
expert_idx
,
// idx of current expert
int
num_experts
,
// number of experts
int
topk
,
// topk parameter of moe
int
prob_m
,
// batch dimension m
int
prob_n
,
// output dimension n
int
prob_k
,
// reduction dimension k
int
tot_m
,
// total number of rows in A and C
int
*
locks
,
// extra global storage for barrier synchronization
bool
replicate_input
,
// do we use the same input for each expert?
bool
apply_weights
,
// apply weights to output
int
current_m_block
// current m block to start kernel computation from
)
{
static
constexpr
auto
w_type
=
vllm
::
ScalarType
::
from_id
(
w_type_id
);
constexpr
int
pack_factor
=
32
/
w_type
.
size_bits
();
// For larger GEMMs we run multiple batchsize 64 versions in parallel for a
// better partitioning with less reductions
int
parallel
=
1
;
if
(
prob_m
>
16
*
thread_m_blocks
)
{
parallel
=
prob_m
/
(
16
*
thread_m_blocks
);
prob_m
=
16
*
thread_m_blocks
;
}
int
k_tiles
=
prob_k
/
16
/
thread_k_blocks
;
int
n_tiles
=
prob_n
/
16
/
thread_n_blocks
;
int
iters
=
ceildiv
(
k_tiles
*
n_tiles
*
parallel
,
gridDim
.
x
);
if
constexpr
(
!
has_act_order
&&
group_blocks
!=
-
1
)
{
if
(
group_blocks
>=
thread_k_blocks
)
{
// Ensure that the number of tiles in each stripe is a multiple of the
// groupsize; this avoids an annoying special case where a stripe starts
// in the middle of group.
iters
=
(
group_blocks
/
thread_k_blocks
)
*
ceildiv
(
iters
,
(
group_blocks
/
thread_k_blocks
));
}
}
int
slice_row
=
(
iters
*
blockIdx
.
x
)
%
k_tiles
;
int
slice_col_par
=
(
iters
*
blockIdx
.
x
)
/
k_tiles
;
int
slice_col
=
slice_col_par
;
int
slice_iters
;
// number of threadblock tiles in the current slice
int
slice_count
=
0
;
// total number of active threadblocks in the current slice
int
slice_idx
;
// index of threadblock in current slice; numbered bottom to
// top
// We can easily implement parallel problem execution by just remapping
// indices and advancing global pointers
if
(
slice_col_par
>=
n_tiles
)
{
locks
+=
(
slice_col_par
/
n_tiles
)
*
n_tiles
;
slice_col
=
slice_col_par
%
n_tiles
;
sorted_ids
+=
(
slice_col_par
/
n_tiles
)
*
16
*
thread_m_blocks
;
}
// Compute all information about the current slice which is required for
// synchronization.
auto
init_slice
=
[
&
]()
{
slice_iters
=
iters
*
(
blockIdx
.
x
+
1
)
-
(
k_tiles
*
slice_col_par
+
slice_row
);
if
(
slice_iters
<
0
||
slice_col_par
>=
n_tiles
*
parallel
)
slice_iters
=
0
;
if
(
slice_iters
==
0
)
return
;
if
(
slice_row
+
slice_iters
>
k_tiles
)
slice_iters
=
k_tiles
-
slice_row
;
slice_count
=
1
;
slice_idx
=
0
;
int
col_first
=
iters
*
ceildiv
(
k_tiles
*
slice_col_par
,
iters
);
if
(
col_first
<=
k_tiles
*
(
slice_col_par
+
1
))
{
int
col_off
=
col_first
-
k_tiles
*
slice_col_par
;
slice_count
=
ceildiv
(
k_tiles
-
col_off
,
iters
);
if
(
col_off
>
0
)
slice_count
++
;
int
delta_first
=
iters
*
blockIdx
.
x
-
col_first
;
if
(
delta_first
<
0
||
(
col_off
==
0
&&
delta_first
==
0
))
slice_idx
=
slice_count
-
1
;
else
{
slice_idx
=
slice_count
-
1
-
delta_first
/
iters
;
if
(
col_off
>
0
)
slice_idx
--
;
}
}
if
(
slice_col
==
n_tiles
)
{
sorted_ids
+=
16
*
thread_m_blocks
;
locks
+=
n_tiles
;
slice_col
=
0
;
}
};
init_slice
();
// A sizes/strides
// stride of the A matrix in global memory
int
a_gl_stride
=
prob_k
/
8
;
// stride of an A matrix tile in shared memory
constexpr
int
a_sh_stride
=
16
*
thread_k_blocks
/
8
;
// delta between subsequent A tiles in global memory
constexpr
int
a_gl_rd_delta_o
=
16
*
thread_k_blocks
/
8
;
// between subsequent accesses within a tile
int
a_gl_rd_delta_i
=
a_gl_stride
*
(
threads
/
a_gl_rd_delta_o
);
// between shared memory writes
constexpr
int
a_sh_wr_delta
=
a_sh_stride
*
(
threads
/
a_gl_rd_delta_o
);
// between shared memory tile reads
constexpr
int
a_sh_rd_delta_o
=
2
*
((
threads
/
32
)
/
(
thread_n_blocks
/
4
));
// within a shared memory tile
constexpr
int
a_sh_rd_delta_i
=
a_sh_stride
*
16
;
// overall size of a tile
constexpr
int
a_sh_stage
=
a_sh_stride
*
(
16
*
thread_m_blocks
);
// number of shared write iterations for a tile
constexpr
int
a_sh_wr_iters
=
ceildiv
(
a_sh_stage
,
a_sh_wr_delta
);
// B sizes/strides
int
b_gl_stride
=
16
*
prob_n
/
(
pack_factor
*
4
);
constexpr
int
b_sh_stride
=
((
thread_n_blocks
*
16
)
*
16
/
pack_factor
)
/
4
;
constexpr
int
b_thread_vecs
=
w_type
.
size_bits
()
==
4
?
1
:
2
;
constexpr
int
b_sh_stride_threads
=
b_sh_stride
/
b_thread_vecs
;
int
b_gl_rd_delta_o
=
b_gl_stride
*
thread_k_blocks
;
int
b_gl_rd_delta_i
=
b_gl_stride
*
(
threads
/
b_sh_stride_threads
);
constexpr
int
b_sh_wr_delta
=
threads
*
b_thread_vecs
;
constexpr
int
b_sh_rd_delta
=
threads
*
b_thread_vecs
;
constexpr
int
b_sh_stage
=
b_sh_stride
*
thread_k_blocks
;
constexpr
int
b_sh_wr_iters
=
b_sh_stage
/
b_sh_wr_delta
;
// Scale sizes/strides without act_order
int
s_gl_stride
=
prob_n
/
8
;
constexpr
int
s_sh_stride
=
16
*
thread_n_blocks
/
8
;
constexpr
int
s_tb_groups
=
!
has_act_order
&&
group_blocks
!=
-
1
&&
group_blocks
<
thread_k_blocks
?
thread_k_blocks
/
group_blocks
:
1
;
constexpr
int
s_sh_stage
=
s_tb_groups
*
s_sh_stride
;
int
s_gl_rd_delta
=
s_gl_stride
;
// Scale size/strides with act_order
constexpr
int
tb_k
=
16
*
thread_k_blocks
;
constexpr
int
g_idx_stage
=
has_act_order
?
(
tb_k
*
sizeof
(
int
))
/
16
:
0
;
// constexpr int act_s_row_stride = 1;
// int act_s_col_stride = act_s_row_stride * num_groups;
int
act_s_col_stride
=
1
;
int
act_s_col_warp_stride
=
act_s_col_stride
*
8
;
int
tb_n_warps
=
thread_n_blocks
/
4
;
int
act_s_col_tb_stride
=
act_s_col_warp_stride
*
tb_n_warps
;
// Zero-points sizes/strides
int
zp_gl_stride
=
(
prob_n
/
pack_factor
)
/
4
;
constexpr
int
zp_sh_stride
=
((
16
*
thread_n_blocks
)
/
pack_factor
)
/
4
;
constexpr
int
zp_tb_groups
=
s_tb_groups
;
constexpr
int
zp_sh_stage
=
has_zp
?
zp_tb_groups
*
zp_sh_stride
:
0
;
int
zp_gl_rd_delta
=
zp_gl_stride
;
// Global A read index of current thread.
int
a_gl_rd
=
a_gl_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
a_gl_rd
+=
a_gl_rd_delta_o
*
slice_row
;
// Shared write index of current thread.
int
a_sh_wr
=
a_sh_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
// Shared read index.
int
a_sh_rd
=
a_sh_stride
*
((
threadIdx
.
x
%
32
)
%
16
)
+
(
threadIdx
.
x
%
32
)
/
16
;
a_sh_rd
+=
2
*
((
threadIdx
.
x
/
32
)
/
(
thread_n_blocks
/
4
));
int
b_gl_rd
=
b_gl_stride
*
(
threadIdx
.
x
/
b_sh_stride_threads
)
+
(
threadIdx
.
x
%
b_sh_stride_threads
)
*
b_thread_vecs
;
b_gl_rd
+=
b_sh_stride
*
slice_col
;
b_gl_rd
+=
b_gl_rd_delta_o
*
slice_row
;
int
b_sh_wr
=
threadIdx
.
x
*
b_thread_vecs
;
int
b_sh_rd
=
threadIdx
.
x
*
b_thread_vecs
;
// For act_order
constexpr
int
k_iter_size
=
tb_k
/
b_sh_wr_iters
;
int
slice_k_start
=
tb_k
*
slice_row
;
int
slice_k_finish
=
slice_k_start
+
tb_k
*
slice_iters
;
int
slice_k_start_shared_fetch
=
slice_k_start
;
int
slice_n_offset
=
act_s_col_tb_stride
*
slice_col
;
// No act_order
int
s_gl_rd
;
if
constexpr
(
!
has_act_order
)
{
if
constexpr
(
group_blocks
==
-
1
)
{
s_gl_rd
=
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
else
{
s_gl_rd
=
s_gl_stride
*
((
thread_k_blocks
*
slice_row
)
/
group_blocks
)
+
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
}
int
s_sh_wr
=
threadIdx
.
x
;
bool
s_sh_wr_pred
=
threadIdx
.
x
<
s_sh_stride
;
// Zero-points
int
zp_gl_rd
;
if
constexpr
(
has_zp
)
{
if
constexpr
(
group_blocks
==
-
1
)
{
zp_gl_rd
=
zp_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
else
{
zp_gl_rd
=
zp_gl_stride
*
((
thread_k_blocks
*
slice_row
)
/
group_blocks
)
+
zp_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
}
int
zp_sh_wr
=
threadIdx
.
x
;
bool
zp_sh_wr_pred
=
threadIdx
.
x
<
zp_sh_stride
;
// We use a different scale layout for grouped and column-wise quantization as
// we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case.
int
s_sh_rd
;
if
constexpr
(
group_blocks
!=
-
1
)
s_sh_rd
=
8
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
(
threadIdx
.
x
%
32
)
/
4
;
else
s_sh_rd
=
8
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
(
threadIdx
.
x
%
32
)
%
4
;
// Zero-points have the same read layout as the scales
// (without column-wise case)
constexpr
int
num_col_threads
=
8
;
constexpr
int
num_row_threads
=
4
;
constexpr
int
num_ints_per_thread
=
8
/
pack_factor
;
int
zp_sh_rd
;
if
constexpr
(
has_zp
)
{
zp_sh_rd
=
num_ints_per_thread
*
num_col_threads
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
num_ints_per_thread
*
((
threadIdx
.
x
%
32
)
/
num_row_threads
);
}
int
sh_first_group_id
=
-
1
;
int
sh_num_groups
=
-
1
;
constexpr
int
sh_max_num_groups
=
32
;
extern
__shared__
int4
sh
[];
// Shared memory storage for global fetch pipelines.
int4
*
sh_a
=
sh
;
int4
*
sh_b
=
sh_a
+
(
stages
*
a_sh_stage
);
int4
*
sh_g_idx
=
sh_b
+
(
stages
*
b_sh_stage
);
int4
*
sh_zp
=
sh_g_idx
+
(
stages
*
g_idx_stage
);
int4
*
sh_s
=
sh_zp
+
(
stages
*
zp_sh_stage
);
// Precompute which thread should not read memory in which iterations; this is
// needed if there are more threads than required for a certain tilesize or
// when the batchsize is not a multiple of 16.
bool
a_sh_wr_pred
[
a_sh_wr_iters
];
#pragma unroll
for
(
int
i
=
0
;
i
<
a_sh_wr_iters
;
i
++
)
{
int
a_idx
=
a_sh_wr_delta
*
i
+
a_sh_wr
;
int
row
=
a_idx
/
a_gl_rd_delta_o
;
if
(
row
>=
prob_m
)
{
a_sh_wr_pred
[
i
]
=
false
;
}
else
{
a_sh_wr_pred
[
i
]
=
a_sh_wr_delta
*
i
+
a_sh_wr
<
a_sh_stride
*
prob_m
;
}
}
// To ensure that writing and reading A tiles to/from shared memory, the
// latter in fragment format, is fully bank conflict free, we need to use a
// rather fancy XOR-based layout. The key here is that neither reads nor
// writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
// same shared memory banks. Further, it seems (based on NSight-Compute) that
// each warp must also write a consecutive memory segment?
auto
transform_a
=
[
&
](
int
i
)
{
int
row
=
i
/
a_gl_rd_delta_o
;
return
a_gl_rd_delta_o
*
row
+
(
i
%
a_gl_rd_delta_o
)
^
row
;
};
// Since the computation of this remapping is non-trivial and, due to our main
// loop unrolls, all shared memory accesses are static, we simply precompute
// both transformed reads and writes.
int
a_sh_wr_trans
[
a_sh_wr_iters
];
#pragma unroll
for
(
int
i
=
0
;
i
<
a_sh_wr_iters
;
i
++
)
a_sh_wr_trans
[
i
]
=
transform_a
(
a_sh_wr_delta
*
i
+
a_sh_wr
);
int
a_sh_rd_trans
[
b_sh_wr_iters
][
thread_m_blocks
];
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
thread_m_blocks
;
j
++
)
a_sh_rd_trans
[
i
][
j
]
=
transform_a
(
a_sh_rd_delta_o
*
i
+
a_sh_rd_delta_i
*
j
+
a_sh_rd
);
}
// Since B-accesses have non-constant stride they have to be computed at
// runtime; we break dependencies between subsequent accesses with a tile by
// maintining multiple pointers (we have enough registers), a tiny
// optimization.
const
int4
*
B_ptr
[
b_sh_wr_iters
];
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
B_ptr
[
i
]
=
B
+
b_gl_rd_delta_i
*
i
+
b_gl_rd
;
// Register storage for double buffer of shared memory reads.
FragA
frag_a
[
2
][
thread_m_blocks
];
I4
frag_b_quant
[
2
][
b_thread_vecs
];
FragC
frag_c
[
thread_m_blocks
][
4
][
2
];
FragS
frag_s
[
2
][
4
];
// No act-order
FragS
act_frag_s
[
2
][
4
][
4
];
// For act-order
int
frag_qzp
[
2
][
num_ints_per_thread
];
// Zero-points
FragZP
frag_zp
;
// Zero-points in fp16
// Zero accumulators.
auto
zero_accums
=
[
&
]()
{
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
*
4
*
2
*
4
;
i
++
)
reinterpret_cast
<
float
*>
(
frag_c
)[
i
]
=
0
;
};
auto
fetch_scales_to_shared
=
[
&
](
bool
is_async
,
int
first_group_id
,
int
last_group_id
)
{
sh_first_group_id
=
first_group_id
;
sh_num_groups
=
last_group_id
-
first_group_id
+
1
;
if
(
sh_num_groups
<
sh_max_num_groups
)
{
sh_num_groups
=
sh_max_num_groups
;
}
if
(
sh_first_group_id
+
sh_num_groups
>
num_groups
)
{
sh_num_groups
=
num_groups
-
sh_first_group_id
;
}
int
row_offset
=
first_group_id
*
s_gl_stride
;
if
(
is_async
)
{
for
(
int
i
=
0
;
i
<
sh_num_groups
;
i
++
)
{
if
(
threadIdx
.
x
<
s_sh_stride
)
{
cp_async4_pred
(
&
sh_s
[(
i
*
s_sh_stride
)
+
threadIdx
.
x
],
&
scales_ptr
[
row_offset
+
(
i
*
s_gl_stride
)
+
slice_n_offset
+
threadIdx
.
x
]);
}
}
}
else
{
for
(
int
i
=
0
;
i
<
sh_num_groups
;
i
++
)
{
if
(
threadIdx
.
x
<
s_sh_stride
)
{
sh_s
[(
i
*
s_sh_stride
)
+
threadIdx
.
x
]
=
scales_ptr
[
row_offset
+
(
i
*
s_gl_stride
)
+
slice_n_offset
+
threadIdx
.
x
];
}
}
}
};
// Asynchronously fetch the next A, B and s tile from global to the next
// shared memory pipeline location.
auto
fetch_to_shared
=
[
&
](
int
pipe
,
int
a_off
,
bool
pred
=
true
)
{
if
(
pred
)
{
int4
*
sh_a_stage
=
sh_a
+
a_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
a_sh_wr_iters
;
i
++
)
{
int
a_idx
=
a_gl_rd_delta_i
*
i
+
a_gl_rd
+
a_gl_rd_delta_o
*
a_off
;
int
row
=
a_idx
/
a_gl_stride
;
int
sorted_row
=
replicate_input
?
sorted_ids
[
row
]
/
topk
:
sorted_ids
[
row
];
int
new_idx
=
sorted_row
*
a_gl_stride
+
a_idx
%
a_gl_stride
;
if
(
sorted_row
<
tot_m
*
(
replicate_input
?
1
:
topk
)
&&
new_idx
<
a_gl_stride
*
tot_m
*
(
replicate_input
?
1
:
topk
))
{
cp_async4_pred
(
&
sh_a_stage
[
a_sh_wr_trans
[
i
]],
&
A
[
new_idx
],
a_sh_wr_pred
[
i
]);
}
}
int4
*
sh_b_stage
=
sh_b
+
b_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
b_thread_vecs
;
j
++
)
{
cp_async4
(
&
sh_b_stage
[
b_sh_wr_delta
*
i
+
b_sh_wr
+
j
],
B_ptr
[
i
]
+
j
);
}
B_ptr
[
i
]
+=
b_gl_rd_delta_o
;
}
if
constexpr
(
has_act_order
)
{
// Fetch g_idx thread-block portion
int
full_pipe
=
a_off
;
int
cur_k
=
slice_k_start_shared_fetch
+
tb_k
*
full_pipe
;
if
(
cur_k
<
prob_k
&&
cur_k
<
slice_k_finish
)
{
int4
*
sh_g_idx_stage
=
sh_g_idx
+
g_idx_stage
*
pipe
;
int4
const
*
cur_g_idx_stage_ptr
=
reinterpret_cast
<
int4
const
*>
(
&
g_idx
[
cur_k
]);
if
(
threadIdx
.
x
<
g_idx_stage
)
{
cp_async4_pred
(
&
sh_g_idx_stage
[
threadIdx
.
x
],
&
cur_g_idx_stage_ptr
[
threadIdx
.
x
]);
}
}
}
else
{
if
constexpr
(
group_blocks
!=
-
1
)
{
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
pipe
;
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
// Only fetch scales if this tile starts a new group
if
(
pipe
%
(
group_blocks
/
thread_k_blocks
)
==
0
)
{
if
(
s_sh_wr_pred
)
{
cp_async4
(
&
sh_s_stage
[
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
}
s_gl_rd
+=
s_gl_rd_delta
;
}
}
else
{
for
(
int
i
=
0
;
i
<
s_tb_groups
;
i
++
)
{
if
(
s_sh_wr_pred
)
{
cp_async4
(
&
sh_s_stage
[
i
*
s_sh_stride
+
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
}
s_gl_rd
+=
s_gl_rd_delta
;
}
}
}
if
constexpr
(
has_zp
&&
group_blocks
!=
-
1
)
{
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
pipe
;
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
// Only fetch zero-points if this tile starts a new group
if
(
pipe
%
(
group_blocks
/
thread_k_blocks
)
==
0
)
{
if
(
zp_sh_wr_pred
)
{
cp_async4
(
&
sh_zp_stage
[
zp_sh_wr
],
&
zp_ptr
[
zp_gl_rd
]);
}
zp_gl_rd
+=
zp_gl_rd_delta
;
}
}
else
{
for
(
int
i
=
0
;
i
<
zp_tb_groups
;
i
++
)
{
if
(
zp_sh_wr_pred
)
{
cp_async4
(
&
sh_zp_stage
[
i
*
zp_sh_stride
+
zp_sh_wr
],
&
zp_ptr
[
zp_gl_rd
]);
}
zp_gl_rd
+=
zp_gl_rd_delta
;
}
}
}
}
}
// Insert a fence even when we are winding down the pipeline to ensure that
// waiting is also correct at this point.
cp_async_fence
();
};
auto
fetch_zp_to_shared
=
[
&
]()
{
if
(
zp_sh_wr_pred
)
{
cp_async4
(
&
sh_zp
[
zp_sh_wr
],
&
zp_ptr
[
zp_gl_rd
]);
}
};
// Wait until the next thread tile has been loaded to shared memory.
auto
wait_for_stage
=
[
&
]()
{
// We only have `stages - 2` active fetches since we are double buffering
// and can only issue the next fetch when it is guaranteed that the previous
// shared memory load is fully complete (as it may otherwise be
// overwritten).
cp_async_wait
<
stages
-
2
>
();
__syncthreads
();
};
// Load the next sub-tile from the current location in the shared memory pipe
// into the current register buffer.
auto
fetch_to_registers
=
[
&
](
int
k
,
int
pipe
)
{
int4
*
sh_a_stage
=
sh_a
+
a_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
ldsm4
(
frag_a
[
k
%
2
][
i
],
&
sh_a_stage
[
a_sh_rd_trans
[
k
%
b_sh_wr_iters
][
i
]]);
int4
*
sh_b_stage
=
sh_b
+
b_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
b_thread_vecs
;
i
++
)
{
frag_b_quant
[
k
%
2
][
i
]
=
*
reinterpret_cast
<
I4
*>
(
&
sh_b_stage
[
b_sh_rd_delta
*
(
k
%
b_sh_wr_iters
)
+
b_sh_rd
+
i
]);
}
};
bool
is_same_group
[
stages
];
int
same_group_id
[
stages
];
auto
init_same_group
=
[
&
](
int
pipe
)
{
if
constexpr
(
!
has_act_order
)
{
is_same_group
[
pipe
]
=
false
;
same_group_id
[
pipe
]
=
0
;
return
;
}
int4
*
sh_g_idx_stage
=
sh_g_idx
+
g_idx_stage
*
pipe
;
int
*
sh_g_idx_int_ptr
=
reinterpret_cast
<
int
*>
(
sh_g_idx_stage
);
int
group_id_1
=
sh_g_idx_int_ptr
[
0
];
int
group_id_2
=
sh_g_idx_int_ptr
[
tb_k
-
1
];
is_same_group
[
pipe
]
=
group_id_1
==
group_id_2
;
same_group_id
[
pipe
]
=
group_id_1
;
};
auto
fetch_scales_to_registers
=
[
&
](
int
k
,
int
full_pipe
)
{
int
pipe
=
full_pipe
%
stages
;
if
constexpr
(
!
has_act_order
)
{
// No act-order case
if
constexpr
(
group_blocks
!=
-
1
)
{
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
reinterpret_cast
<
int4
*>
(
&
frag_s
[
k
%
2
])[
0
]
=
sh_s_stage
[
s_sh_rd
];
}
else
{
int
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
int
warp_row
=
warp_id
/
n_warps
;
int
cur_k
=
warp_row
*
16
;
cur_k
+=
k_iter_size
*
(
k
%
b_sh_wr_iters
);
int
k_blocks
=
cur_k
/
16
;
int
cur_group_id
=
k_blocks
/
group_blocks
;
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
pipe
;
reinterpret_cast
<
int4
*>
(
&
frag_s
[
k
%
2
])[
0
]
=
sh_s_stage
[
s_sh_rd
+
cur_group_id
*
s_sh_stride
];
}
}
return
;
}
// Act-order case
// Determine K of the "current" thread-block
int
cur_k
=
slice_k_start
+
tb_k
*
full_pipe
;
if
(
cur_k
>=
prob_k
||
cur_k
>=
slice_k_finish
)
{
return
;
}
// Reset (to current thread-block) since we read g_idx portion from the
// shared memory
cur_k
=
0
;
// Progress to current iteration
cur_k
+=
k_iter_size
*
(
k
%
b_sh_wr_iters
);
// Determine "position" inside the thread-block (based on warp and
// thread-id)
int
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
// Each warp processes 4 16-size tiles over N
int
warp_row
=
warp_id
/
n_warps
;
int
warp_col
=
warp_id
%
n_warps
;
cur_k
+=
warp_row
*
16
;
int
th_id
=
threadIdx
.
x
%
32
;
cur_k
+=
(
th_id
%
4
)
*
2
;
// Due to tensor-core layout for fp16 B matrix
int
s_col_shift
=
/*slice_n_offset +*/
(
act_s_col_warp_stride
*
warp_col
)
+
(
th_id
/
4
)
*
act_s_col_stride
;
if
(
is_same_group
[
pipe
])
{
if
(
k
%
2
==
0
)
{
*
(
reinterpret_cast
<
int4
*>
(
&
(
act_frag_s
[
k
%
2
][
0
][
0
])))
=
sh_s
[(
same_group_id
[
pipe
]
-
sh_first_group_id
)
*
s_sh_stride
+
s_col_shift
];
}
else
{
*
(
reinterpret_cast
<
int4
*>
(
&
(
act_frag_s
[
k
%
2
][
0
][
0
])))
=
*
(
reinterpret_cast
<
int4
*>
(
&
(
act_frag_s
[(
k
-
1
)
%
2
][
0
][
0
])));
}
for
(
int
i
=
1
;
i
<
4
;
i
++
)
{
*
(
reinterpret_cast
<
int4
*>
(
&
(
act_frag_s
[
k
%
2
][
i
][
0
])))
=
*
(
reinterpret_cast
<
int4
*>
(
&
(
act_frag_s
[
k
%
2
][
0
][
0
])));
}
return
;
}
int4
*
sh_g_idx_stage
=
sh_g_idx
+
g_idx_stage
*
pipe
;
int
*
sh_g_idx_int_ptr
=
reinterpret_cast
<
int
*>
(
sh_g_idx_stage
);
constexpr
int
k_frag_offsets
[
4
]
=
{
0
,
1
,
8
,
9
};
// Tensor core offsets per thread
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
int
actual_k
=
cur_k
+
k_frag_offsets
[
i
];
int
group_id
=
sh_g_idx_int_ptr
[
actual_k
];
int
rel_group_id
=
group_id
-
sh_first_group_id
;
*
(
reinterpret_cast
<
int4
*>
(
&
(
act_frag_s
[
k
%
2
][
i
][
0
])))
=
sh_s
[
rel_group_id
*
s_sh_stride
+
s_col_shift
];
}
};
auto
fetch_zp_to_registers
=
[
&
](
int
k
,
int
full_pipe
)
{
// This code does not handle group_blocks == 0,
// which signifies act_order.
// has_zp implies AWQ, which doesn't have act_order,
static_assert
(
!
has_zp
||
group_blocks
!=
0
);
if
constexpr
(
has_zp
)
{
int
pipe
=
full_pipe
%
stages
;
if
constexpr
(
group_blocks
==
-
1
)
{
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp
))[
zp_sh_rd
+
i
];
}
}
else
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp_stage
))[
zp_sh_rd
+
i
];
}
}
else
{
int
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
int
warp_row
=
warp_id
/
n_warps
;
int
cur_k
=
warp_row
*
16
;
cur_k
+=
k_iter_size
*
(
k
%
b_sh_wr_iters
);
int
k_blocks
=
cur_k
/
16
;
int
cur_group_id
=
0
;
// Suppress bogus and persistent divide-by-zero warning
#pragma nv_diagnostic push
#pragma nv_diag_suppress divide_by_zero
cur_group_id
=
k_blocks
/
group_blocks
;
#pragma nv_diagnostic pop
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
pipe
;
sh_zp_stage
+=
cur_group_id
*
zp_sh_stride
;
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp_stage
))[
zp_sh_rd
+
i
];
}
}
}
};
// Execute the actual tensor core matmul of a sub-tile.
auto
matmul
=
[
&
](
int
k
)
{
if
constexpr
(
has_zp
)
{
FragB
frag_zp_0
;
FragB
frag_zp_1
;
int
zp_quant_0
,
zp_quant_1
;
if
constexpr
(
w_type
.
size_bits
()
==
4
)
{
zp_quant_0
=
frag_qzp
[
k
%
2
][
0
];
zp_quant_1
=
zp_quant_0
>>
8
;
}
else
{
static_assert
(
w_type
.
size_bits
()
==
8
);
zp_quant_0
=
frag_qzp
[
k
%
2
][
0
];
zp_quant_1
=
frag_qzp
[
k
%
2
][
1
];
}
frag_zp_0
=
dequant
<
w_type_id
>
(
zp_quant_0
);
frag_zp_1
=
dequant
<
w_type_id
>
(
zp_quant_1
);
frag_zp
[
0
]
=
frag_zp_0
[
0
];
frag_zp
[
1
]
=
frag_zp_0
[
1
];
frag_zp
[
2
]
=
frag_zp_1
[
0
];
frag_zp
[
3
]
=
frag_zp_1
[
1
];
}
// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
int
b_quant_0
,
b_quant_1
;
if
constexpr
(
w_type
.
size_bits
()
==
4
)
{
b_quant_0
=
frag_b_quant
[
k
%
2
][
0
][
j
];
b_quant_1
=
b_quant_0
>>
8
;
}
else
{
static_assert
(
w_type
.
size_bits
()
==
8
);
int
*
frag_b_quant_ptr
=
reinterpret_cast
<
int
*>
(
frag_b_quant
[
k
%
2
]);
b_quant_0
=
frag_b_quant_ptr
[
j
*
2
+
0
];
b_quant_1
=
frag_b_quant_ptr
[
j
*
2
+
1
];
}
FragB
frag_b0
=
dequant
<
w_type_id
>
(
b_quant_0
);
FragB
frag_b1
=
dequant
<
w_type_id
>
(
b_quant_1
);
// Apply zero-point to frag_b0
if
constexpr
(
has_zp
)
{
sub_zp
(
frag_b0
,
frag_zp
[
j
],
0
);
}
// Apply scale to frag_b0
if
constexpr
(
has_act_order
)
{
scale4
(
frag_b0
,
act_frag_s
[
k
%
2
][
0
][
j
],
act_frag_s
[
k
%
2
][
1
][
j
],
act_frag_s
[
k
%
2
][
2
][
j
],
act_frag_s
[
k
%
2
][
3
][
j
],
0
);
}
else
{
if
constexpr
(
group_blocks
!=
-
1
)
{
scale
(
frag_b0
,
frag_s
[
k
%
2
][
j
],
0
);
}
}
// Apply zero-point to frag_b1
if
constexpr
(
has_zp
)
{
sub_zp
(
frag_b1
,
frag_zp
[
j
],
1
);
}
// Apply scale to frag_b1
if
constexpr
(
has_act_order
)
{
scale4
(
frag_b1
,
act_frag_s
[
k
%
2
][
0
][
j
],
act_frag_s
[
k
%
2
][
1
][
j
],
act_frag_s
[
k
%
2
][
2
][
j
],
act_frag_s
[
k
%
2
][
3
][
j
],
1
);
}
else
{
if
constexpr
(
group_blocks
!=
-
1
)
{
scale
(
frag_b1
,
frag_s
[
k
%
2
][
j
],
1
);
}
}
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
mma
(
frag_a
[
k
%
2
][
i
],
frag_b0
,
frag_c
[
i
][
j
][
0
]);
mma
(
frag_a
[
k
%
2
][
i
],
frag_b1
,
frag_c
[
i
][
j
][
1
]);
}
}
};
// Since we slice across the k dimension of a tile in order to increase the
// number of warps while keeping the n dimension of a tile reasonable, we have
// multiple warps that accumulate their partial sums of the same output
// location; which we have to reduce over in the end. We do in shared memory.
auto
thread_block_reduce
=
[
&
]()
{
constexpr
int
red_off
=
threads
/
b_sh_stride_threads
/
2
;
if
(
red_off
>=
1
)
{
int
red_idx
=
threadIdx
.
x
/
b_sh_stride_threads
;
constexpr
int
red_sh_stride
=
b_sh_stride_threads
*
4
*
2
;
constexpr
int
red_sh_delta
=
b_sh_stride_threads
;
int
red_sh_rd
=
red_sh_stride
*
(
threadIdx
.
x
/
b_sh_stride_threads
)
+
(
threadIdx
.
x
%
b_sh_stride_threads
);
// Parallel logarithmic shared memory reduction. We make sure to avoid any
// unnecessary read or write iterations, e.g., for two warps we write only
// once by warp 1 and read only once by warp 0.
#pragma unroll
for
(
int
m_block
=
0
;
m_block
<
thread_m_blocks
;
m_block
++
)
{
#pragma unroll
for
(
int
i
=
red_off
;
i
>
0
;
i
/=
2
)
{
if
(
i
<=
red_idx
&&
red_idx
<
2
*
i
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
4
*
2
;
j
++
)
{
int
red_sh_wr
=
red_sh_delta
*
j
+
(
red_sh_rd
-
red_sh_stride
*
i
);
if
(
i
<
red_off
)
{
float
*
c_rd
=
reinterpret_cast
<
float
*>
(
&
sh
[
red_sh_delta
*
j
+
red_sh_rd
]);
float
*
c_wr
=
reinterpret_cast
<
float
*>
(
&
sh
[
red_sh_wr
]);
#pragma unroll
for
(
int
k
=
0
;
k
<
4
;
k
++
)
reinterpret_cast
<
FragC
*>
(
frag_c
)[
4
*
2
*
m_block
+
j
][
k
]
+=
c_rd
[
k
]
+
c_wr
[
k
];
}
sh
[
red_sh_wr
]
=
reinterpret_cast
<
int4
*>
(
&
frag_c
)[
4
*
2
*
m_block
+
j
];
}
}
__syncthreads
();
}
if
(
red_idx
==
0
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
4
*
2
;
i
++
)
{
float
*
c_rd
=
reinterpret_cast
<
float
*>
(
&
sh
[
red_sh_delta
*
i
+
red_sh_rd
]);
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
reinterpret_cast
<
FragC
*>
(
frag_c
)[
4
*
2
*
m_block
+
i
][
j
]
+=
c_rd
[
j
];
}
}
__syncthreads
();
}
}
};
// Since multiple threadblocks may process parts of the same column slice, we
// finally have to globally reduce over the results. As the striped
// partitioning minimizes the number of such reductions and our outputs are
// usually rather small, we perform this reduction serially in L2 cache.
auto
global_reduce
=
[
&
](
bool
first
=
false
,
bool
last
=
false
)
{
// We are very careful here to reduce directly in the output buffer to
// maximize L2 cache utilization in this step. To do this, we write out
// results in FP16 (but still reduce with FP32 compute).
constexpr
int
active_threads
=
32
*
thread_n_blocks
/
4
;
if
(
threadIdx
.
x
<
active_threads
)
{
int
c_gl_stride
=
prob_n
/
8
;
int
c_gl_wr_delta_o
=
8
*
c_gl_stride
;
int
c_gl_wr_delta_i
=
4
*
(
active_threads
/
32
);
int
c_gl_wr
=
c_gl_stride
*
((
threadIdx
.
x
%
32
)
/
4
)
+
4
*
(
threadIdx
.
x
/
32
)
+
threadIdx
.
x
%
4
;
c_gl_wr
+=
(
2
*
thread_n_blocks
)
*
slice_col
;
constexpr
int
c_sh_wr_delta
=
active_threads
;
int
c_sh_wr
=
threadIdx
.
x
;
int
row
=
(
threadIdx
.
x
%
32
)
/
4
;
if
(
!
first
)
{
// Interestingly, doing direct global accesses here really seems to mess up
// the compiler and lead to slowdowns, hence we also use async-copies even
// though these fetches are not actually asynchronous.
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
*
4
;
i
++
)
{
int
c_idx
=
c_gl_wr
+
c_gl_wr_delta_o
*
(
i
/
2
)
+
c_gl_wr_delta_i
*
(
i
%
2
);
int
sorted_row
=
sorted_ids
[
c_idx
/
c_gl_stride
];
int
new_idx
=
sorted_row
*
c_gl_stride
+
c_idx
%
c_gl_stride
;
cp_async4_pred
(
&
sh
[
c_sh_wr
+
c_sh_wr_delta
*
i
],
&
C
[
new_idx
],
sorted_row
<
tot_m
*
topk
&&
(
8
*
(
i
/
2
)
+
row
<
prob_m
&&
(
i
<
(
thread_m_blocks
-
1
)
*
4
||
sorted_ids
[
8
*
(
i
/
2
)
+
row
]
<
tot_m
*
topk
)));
}
cp_async_fence
();
cp_async_wait
<
0
>
();
}
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
*
4
;
i
++
)
{
if
(
8
*
(
i
/
2
)
+
row
<
prob_m
&&
(
i
<
(
thread_m_blocks
-
1
)
*
4
||
sorted_ids
[
8
*
(
i
/
2
)
+
row
]
<
tot_m
*
topk
))
{
if
(
!
first
)
{
int4
c_red
=
sh
[
c_sh_wr
+
i
*
c_sh_wr_delta
];
#pragma unroll
for
(
int
j
=
0
;
j
<
2
*
4
;
j
++
)
{
reinterpret_cast
<
float
*>
(
&
frag_c
)[
4
*
2
*
4
*
(
i
/
4
)
+
4
*
j
+
(
i
%
4
)]
+=
__half2float
(
reinterpret_cast
<
__half
*>
(
&
c_red
)[
j
]);
}
}
if
(
!
last
)
{
int4
c
;
#pragma unroll
for
(
int
j
=
0
;
j
<
2
*
4
;
j
++
)
{
reinterpret_cast
<
__half
*>
(
&
c
)[
j
]
=
__float2half
(
reinterpret_cast
<
float
*>
(
&
frag_c
)[
4
*
2
*
4
*
(
i
/
4
)
+
4
*
j
+
(
i
%
4
)]);
}
int
c_idx
=
c_gl_wr
+
c_gl_wr_delta_o
*
(
i
/
2
)
+
c_gl_wr_delta_i
*
(
i
%
2
);
int
row
=
sorted_ids
[
c_idx
/
c_gl_stride
];
if
(
row
<
tot_m
*
topk
)
{
int
new_idx
=
row
*
c_gl_stride
+
c_idx
%
c_gl_stride
;
C
[
new_idx
]
=
c
;
}
}
}
}
}
};
// Write out the reduce final result in the correct layout. We only actually
// reshuffle matrix fragments in this step, the reduction above is performed
// in fragment layout.
auto
write_result
=
[
&
]()
{
int
c_gl_stride
=
prob_n
/
8
;
constexpr
int
c_sh_stride
=
2
*
thread_n_blocks
+
1
;
int
c_gl_wr_delta
=
c_gl_stride
*
(
threads
/
(
2
*
thread_n_blocks
));
constexpr
int
c_sh_rd_delta
=
c_sh_stride
*
(
threads
/
(
2
*
thread_n_blocks
));
int
c_gl_wr
=
c_gl_stride
*
(
threadIdx
.
x
/
(
2
*
thread_n_blocks
))
+
(
threadIdx
.
x
%
(
2
*
thread_n_blocks
));
c_gl_wr
+=
(
2
*
thread_n_blocks
)
*
slice_col
;
int
c_sh_wr
=
(
4
*
c_sh_stride
)
*
((
threadIdx
.
x
%
32
)
/
4
)
+
(
threadIdx
.
x
%
32
)
%
4
;
c_sh_wr
+=
32
*
(
threadIdx
.
x
/
32
);
int
c_sh_rd
=
c_sh_stride
*
(
threadIdx
.
x
/
(
2
*
thread_n_blocks
))
+
(
threadIdx
.
x
%
(
2
*
thread_n_blocks
));
int
c_gl_wr_end
=
c_gl_stride
*
prob_m
;
// We first reorder in shared memory to guarantee the most efficient final
// global write patterns
auto
write
=
[
&
](
int
idx
,
float
c0
,
float
c1
,
FragS
&
s
)
{
half2
res
=
__halves2half2
(
__float2half
(
c0
),
__float2half
(
c1
));
// For per-column quantization we finally apply the scale here (only for
// 4-bit)
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
&&
w_type
.
size_bits
()
==
4
)
{
res
=
__hmul2
(
res
,
s
[
0
]);
}
((
half2
*
)
sh
)[
idx
]
=
res
;
};
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
int
wr
=
c_sh_wr
+
8
*
j
;
write
(
wr
+
(
4
*
c_sh_stride
)
*
0
+
0
,
frag_c
[
i
][
j
][
0
][
0
],
frag_c
[
i
][
j
][
0
][
1
],
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
0
]);
write
(
wr
+
(
4
*
c_sh_stride
)
*
8
+
0
,
frag_c
[
i
][
j
][
0
][
2
],
frag_c
[
i
][
j
][
0
][
3
],
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
0
]);
write
(
wr
+
(
4
*
c_sh_stride
)
*
0
+
4
,
frag_c
[
i
][
j
][
1
][
0
],
frag_c
[
i
][
j
][
1
][
1
],
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
1
]);
write
(
wr
+
(
4
*
c_sh_stride
)
*
8
+
4
,
frag_c
[
i
][
j
][
1
][
2
],
frag_c
[
i
][
j
][
1
][
3
],
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
1
]);
}
c_sh_wr
+=
16
*
(
4
*
c_sh_stride
);
}
}
__syncthreads
();
#pragma unroll
for
(
int
i
=
0
;
i
<
ceildiv
(
16
*
thread_m_blocks
,
threads
/
(
2
*
thread_n_blocks
));
i
++
)
{
if
(
c_gl_wr
<
c_gl_wr_end
)
{
int
row
=
sorted_ids
[
c_gl_wr
/
c_gl_stride
];
if
(
row
<
tot_m
*
topk
)
{
int
off
=
row
*
c_gl_stride
+
c_gl_wr
%
c_gl_stride
;
if
(
!
apply_weights
)
{
C
[
off
]
=
sh
[
c_sh_rd
];
}
else
{
__half
*
ctrg
=
reinterpret_cast
<
__half
*>
(
&
C
[
off
]);
__half
*
csrc
=
reinterpret_cast
<
__half
*>
(
&
sh
[
c_sh_rd
]);
for
(
int
j
=
0
;
j
<
8
;
++
j
)
{
ctrg
[
j
]
=
__float2half
(
topk_weights
[
row
]
*
__half2float
(
csrc
[
j
]));
}
}
c_gl_wr
+=
c_gl_wr_delta
;
c_sh_rd
+=
c_sh_rd_delta
;
}
}
}
};
// Start global fetch and register load pipelines.
auto
start_pipes
=
[
&
]()
{
#pragma unroll
for
(
int
i
=
0
;
i
<
stages
-
1
;
i
++
)
{
if
(
has_act_order
&&
i
==
0
)
{
int
last_g_idx
=
slice_k_start
+
stages
*
tb_k
*
2
;
if
(
last_g_idx
>=
prob_k
)
{
last_g_idx
=
prob_k
-
1
;
}
fetch_scales_to_shared
(
true
,
g_idx
[
slice_k_start
],
g_idx
[
last_g_idx
]);
}
if
constexpr
(
has_zp
&&
group_blocks
==
-
1
)
{
if
(
i
==
0
)
{
fetch_zp_to_shared
();
}
}
fetch_to_shared
(
i
,
i
,
i
<
slice_iters
);
}
zero_accums
();
wait_for_stage
();
init_same_group
(
0
);
fetch_to_registers
(
0
,
0
);
fetch_scales_to_registers
(
0
,
0
);
fetch_zp_to_registers
(
0
,
0
);
a_gl_rd
+=
a_gl_rd_delta_o
*
(
stages
-
1
);
slice_k_start_shared_fetch
+=
tb_k
*
(
stages
-
1
);
};
if
(
slice_iters
)
{
start_pipes
();
}
// Main loop.
while
(
slice_iters
)
{
// We unroll over both the global fetch and the register load pipeline to
// ensure all shared memory accesses are static. Note that both pipelines
// have even length meaning that the next iteration will always start at
// index 0.
#pragma unroll
for
(
int
pipe
=
0
;
pipe
<
stages
;)
{
#pragma unroll
for
(
int
k
=
0
;
k
<
b_sh_wr_iters
;
k
++
)
{
fetch_to_registers
(
k
+
1
,
pipe
%
stages
);
fetch_scales_to_registers
(
k
+
1
,
pipe
);
fetch_zp_to_registers
(
k
+
1
,
pipe
);
if
(
k
==
b_sh_wr_iters
-
2
)
{
fetch_to_shared
((
pipe
+
stages
-
1
)
%
stages
,
pipe
,
slice_iters
>=
stages
);
pipe
++
;
wait_for_stage
();
init_same_group
(
pipe
%
stages
);
}
matmul
(
k
);
}
slice_iters
--
;
if
(
slice_iters
==
0
)
{
break
;
}
}
a_gl_rd
+=
a_gl_rd_delta_o
*
stages
;
slice_k_start
+=
tb_k
*
stages
;
slice_k_start_shared_fetch
+=
tb_k
*
stages
;
if
constexpr
(
has_act_order
)
{
int
first_group_id
=
g_idx
[
slice_k_start
];
int
last_g_idx
=
slice_k_start
+
stages
*
tb_k
*
2
;
if
(
last_g_idx
>=
prob_k
)
{
last_g_idx
=
prob_k
-
1
;
}
int
last_group_id
=
g_idx
[
last_g_idx
];
if
(
last_group_id
>=
sh_first_group_id
+
sh_num_groups
)
{
fetch_scales_to_shared
(
false
,
first_group_id
,
last_group_id
);
__syncthreads
();
}
}
// Process results and, if necessary, proceed to the next column slice.
// While this pattern may not be the most readable, other ways of writing
// the loop seemed to noticeably worse performance after compilation.
if
(
slice_iters
==
0
)
{
cp_async_wait
<
0
>
();
bool
last
=
slice_idx
==
slice_count
-
1
;
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
)
{
if
constexpr
(
w_type
.
size_bits
()
==
8
)
{
if
(
s_sh_wr_pred
)
{
cp_async4
(
&
sh_s
[
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
}
cp_async_fence
();
}
else
{
// For 4-bit per-column scales, we only fetch them here in the
// final step before write-out
if
(
last
)
{
if
(
s_sh_wr_pred
)
{
cp_async4
(
&
sh_s
[
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
}
cp_async_fence
();
}
}
}
thread_block_reduce
();
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
)
{
if
constexpr
(
w_type
.
size_bits
()
==
8
)
{
cp_async_wait
<
0
>
();
__syncthreads
();
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
reinterpret_cast
<
int4
*>
(
&
frag_s
)[
0
]
=
sh_s
[
s_sh_rd
+
0
];
reinterpret_cast
<
int4
*>
(
&
frag_s
)[
1
]
=
sh_s
[
s_sh_rd
+
4
];
}
}
else
{
if
(
last
)
{
cp_async_wait
<
0
>
();
__syncthreads
();
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
reinterpret_cast
<
int4
*>
(
&
frag_s
)[
0
]
=
sh_s
[
s_sh_rd
+
0
];
reinterpret_cast
<
int4
*>
(
&
frag_s
)[
1
]
=
sh_s
[
s_sh_rd
+
4
];
}
}
}
}
// For 8-bit channelwise, we apply the scale before the global reduction
// that converts the fp32 results to fp16 (so that we avoid possible
// overflow in fp16)
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
&&
w_type
.
size_bits
()
==
8
)
{
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
scale_float
(
reinterpret_cast
<
float
*>
(
&
frag_c
[
i
][
j
][
0
][
0
]),
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
0
]);
scale_float
(
reinterpret_cast
<
float
*>
(
&
frag_c
[
i
][
j
][
0
][
2
]),
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
0
]);
scale_float
(
reinterpret_cast
<
float
*>
(
&
frag_c
[
i
][
j
][
1
][
0
]),
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
1
]);
scale_float
(
reinterpret_cast
<
float
*>
(
&
frag_c
[
i
][
j
][
1
][
2
]),
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
1
]);
}
}
}
}
if
(
slice_count
>
1
)
{
// only globally reduce if there is more than one
// block in a slice
barrier_acquire
(
&
locks
[
slice_col
],
slice_idx
);
global_reduce
(
slice_idx
==
0
,
last
);
barrier_release
(
&
locks
[
slice_col
],
last
);
}
if
(
last
)
// only the last block in a slice actually writes the result
write_result
();
slice_row
=
0
;
slice_col_par
++
;
slice_col
++
;
init_slice
();
if
(
slice_iters
)
{
a_gl_rd
=
a_gl_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
B_ptr
[
i
]
+=
b_sh_stride
-
b_gl_rd_delta_o
*
k_tiles
;
if
(
slice_col
==
0
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
B_ptr
[
i
]
-=
b_gl_stride
;
}
// Update slice k/n for scales loading
if
constexpr
(
has_act_order
)
{
slice_k_start
=
tb_k
*
slice_row
;
slice_k_finish
=
slice_k_start
+
tb_k
*
slice_iters
;
slice_k_start_shared_fetch
=
slice_k_start
;
slice_n_offset
=
act_s_col_tb_stride
*
slice_col
;
}
else
{
s_gl_rd
=
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
zp_gl_rd
=
zp_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
start_pipes
();
}
}
}
}
template
<
const
vllm
::
ScalarTypeId
w_type_id
,
// weight ScalarType id
const
int
threads
,
// number of threads in a threadblock
const
int
thread_n_blocks
,
// same for n dimension (output)
const
int
thread_k_blocks
,
// same for k dimension (reduction)
const
int
stages
,
// number of stages for the async global->shared
// fetch pipeline
const
bool
has_act_order
,
// whether act_order is enabled
const
bool
has_zp
,
// whether zero-points are enabled
const
int
group_blocks
=
-
1
// number of consecutive 16x16 blocks
// with a separate quantization scale
>
__global__
void
MarlinMoE
(
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
const
int4
*
__restrict__
B
,
// 4bit quantized weight matrix of shape kxn
int4
*
__restrict__
C
,
// fp16 output buffer of shape mxn
const
int
*
__restrict__
sorted_ids_base
,
// int32 sorted ids of experts
const
float
*
__restrict__
topk_weights
,
// float topk weights
const
int4
*
__restrict__
scales_ptr
,
// fp16 quantization scales of shape
// (k/groupsize)xn
const
int4
*
__restrict__
zp_ptr
,
// 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const
int
*
__restrict__
g_idx
,
// int32 group indices of shape k
const
int
*
__restrict__
expert_offsets
,
int
num_groups
,
// number of scale groups per output channel
int
expert_idx
,
// idx of current expert
int
num_experts
,
// number of experts
int
topk
,
// topk parameter of moe
int
prob_m
,
// batch dimension m
int
prob_n
,
// output dimension n
int
prob_k
,
// reduction dimension k
int
tot_m
,
// total number of rows in A and C
int
*
locks
,
// extra global storage for barrier synchronization
bool
replicate_input
,
// do we use the same input for each expert?
bool
apply_weights
,
// apply weights to output
int
current_m_block
,
// current m block to start kernel computation from
int
max_par
,
// maximum parallelism
int
cfg_max_m_blocks
// upper bound on m blocks
)
{
int
m_block_ctr
=
current_m_block
;
const
int
*
sorted_ids_expert
=
sorted_ids_base
+
expert_offsets
[
expert_idx
]
+
m_block_ctr
*
4
*
max_par
;
int
tot_its
=
expert_offsets
[
expert_idx
+
1
]
-
expert_offsets
[
expert_idx
];
if
(
tot_its
==
0
)
{
return
;
}
int
tot_m_blocks
=
ceildiv
(
tot_its
,
16
);
int
pad
=
16
*
tot_m_blocks
-
tot_its
;
if
(
m_block_ctr
>=
tot_m_blocks
)
{
return
;
}
int
max_block
=
tot_m_blocks
-
m_block_ctr
;
prob_m
=
tot_its
-
16
*
m_block_ctr
;
int
par
=
1
;
if
(
max_block
>
cfg_max_m_blocks
)
{
// Note that parallel > 1 currently only works for inputs without any
// padding
par
=
(
16
*
max_block
-
pad
)
/
(
16
*
cfg_max_m_blocks
);
if
(
par
>
max_par
)
par
=
max_par
;
prob_m
=
(
16
*
cfg_max_m_blocks
)
*
par
;
m_block_ctr
+=
cfg_max_m_blocks
*
(
par
-
1
);
max_block
=
cfg_max_m_blocks
;
}
if
(
max_block
==
1
)
{
MarlinMoESingle
<
w_type_id
,
threads
,
1
,
thread_n_blocks
,
thread_k_blocks
,
stages
,
has_act_order
,
has_zp
,
group_blocks
>
(
A
,
B
,
C
,
sorted_ids_expert
,
topk_weights
,
scales_ptr
,
zp_ptr
,
g_idx
,
expert_offsets
,
num_groups
,
expert_idx
,
num_experts
,
topk
,
prob_m
,
prob_n
,
prob_k
,
tot_m
,
locks
,
replicate_input
,
apply_weights
,
current_m_block
);
}
else
if
(
max_block
==
2
)
{
MarlinMoESingle
<
w_type_id
,
threads
,
2
,
thread_n_blocks
,
thread_k_blocks
,
stages
,
has_act_order
,
has_zp
,
group_blocks
>
(
A
,
B
,
C
,
sorted_ids_expert
,
topk_weights
,
scales_ptr
,
zp_ptr
,
g_idx
,
expert_offsets
,
num_groups
,
expert_idx
,
num_experts
,
topk
,
prob_m
,
prob_n
,
prob_k
,
tot_m
,
locks
,
replicate_input
,
apply_weights
,
current_m_block
);
}
else
if
(
max_block
==
3
)
{
MarlinMoESingle
<
w_type_id
,
threads
,
3
,
thread_n_blocks
,
thread_k_blocks
,
stages
,
has_act_order
,
has_zp
,
group_blocks
>
(
A
,
B
,
C
,
sorted_ids_expert
,
topk_weights
,
scales_ptr
,
zp_ptr
,
g_idx
,
expert_offsets
,
num_groups
,
expert_idx
,
num_experts
,
topk
,
prob_m
,
prob_n
,
prob_k
,
tot_m
,
locks
,
replicate_input
,
apply_weights
,
current_m_block
);
}
else
{
MarlinMoESingle
<
w_type_id
,
threads
,
4
,
thread_n_blocks
,
thread_k_blocks
,
stages
,
has_act_order
,
has_zp
,
group_blocks
>
(
A
,
B
,
C
,
sorted_ids_expert
,
topk_weights
,
scales_ptr
,
zp_ptr
,
g_idx
,
expert_offsets
,
num_groups
,
expert_idx
,
num_experts
,
topk
,
prob_m
,
prob_n
,
prob_k
,
tot_m
,
locks
,
replicate_input
,
apply_weights
,
current_m_block
);
}
}
#else
template
<
const
vllm
::
ScalarTypeId
w_type_id
,
// weight ScalarType id
const
int
threads
,
// number of threads in a threadblock
const
int
thread_n_blocks
,
// same for n dimension (output)
const
int
thread_k_blocks
,
// same for k dimension (reduction)
const
int
stages
,
// number of stages for the async global->shared
// fetch pipeline
const
bool
has_act_order
,
// whether act_order is enabled
const
bool
has_zp
,
// whether zero-points are enabled
const
int
group_blocks
=
-
1
// number of consecutive 16x16 blocks
// with a separate quantization scale
>
__global__
void
MarlinMoE
(
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
const
int4
*
__restrict__
B
,
// 4bit quantized weight matrix of shape kxn
int4
*
__restrict__
C
,
// fp16 output buffer of shape mxn
const
int
*
__restrict__
sorted_ids
,
// int32 sorted ids of experts
const
float
*
__restrict__
topk_weights
,
// float topk weights
const
int4
*
__restrict__
scales_ptr
,
// fp16 quantization scales of shape
// (k/groupsize)xn
const
int4
*
__restrict__
zp_ptr
,
// 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const
int
*
__restrict__
g_idx
,
// int32 group indices of shape k
const
int
*
__restrict__
expert_offsets
,
int
num_groups
,
// number of scale groups per output channel
int
expert_idx
,
// idx of current expert
int
num_experts
,
// number of experts
int
topk
,
// topk parameter of moe
int
prob_m
,
// batch dimension m
int
prob_n
,
// output dimension n
int
prob_k
,
// reduction dimension k
int
tot_m
,
// total number of rows in A and C
int
*
locks
,
// extra global storage for barrier synchronization
bool
replicate_input
,
// do we use the same input for each expert?
bool
apply_weights
,
// apply weights to output
int
current_m_block
,
// current m block to start kernel computation from
int
max_par
,
// maximum parallelism
int
cfg_max_m_blocks
// upper bound on m blocks
)
{
// Marlin is not implemented yet for SM < 8.0
assert
(
false
);
return
;
}
#endif
// 8 warps are a good choice since every SM has 4 schedulers and having more
// than 1 warp per schedule allows some more latency hiding. At the same time,
// we want relatively few warps to have many registers per warp and small tiles.
const
int
USER_THREADS
=
256
;
// Note: This is only used with user-provided thread_k/n
const
int
STAGES
=
4
;
// 4 pipeline stages fit into shared memory
static
constexpr
int
min_thread_n
=
64
;
static
constexpr
int
min_thread_k
=
64
;
#define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, HAS_ACT_ORDER, \
HAS_ZP, GROUP_BLOCKS, NUM_THREADS) \
else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \
cudaFuncSetAttribute( \
MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
STAGES, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
STAGES, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS> \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \
zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \
num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \
replicate_input, apply_weights, m_block, max_par, \
cfg_max_m_blocks); \
}
#define GPTQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)
#define AWQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
}
// namespace marlin_moe
csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu
0 → 100644
View file @
ad385667
#include "marlin_moe_kernel_ku4.h"
namespace
marlin_moe
{
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool
call_marlin_moe_kernel_ku4
(
vllm
::
ScalarType
const
&
q_type
,
int
thread_n_blocks
,
int
thread_k_blocks
,
bool
has_act_order
,
int
group_blocks
,
int
num_threads
,
int
blocks
,
int
max_shared_mem
,
cudaStream_t
stream
,
const
int4
*
A_ptr
,
const
int4
*
B_ptr
,
int4
*
C_ptr
,
const
int
*
sorted_ids_ptr
,
const
float
*
topk_weights_ptr
,
const
int4
*
s_ptr
,
const
int4
*
zp_ptr
,
const
int
*
g_idx_ptr
,
int
*
expert_offsets_ptr
,
int
num_groups
,
int
expert_idx
,
int
num_experts
,
int
topk
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
tot_m
,
int
*
locks
,
bool
replicate_input
,
bool
apply_weights
,
int
m_block
,
int
max_par
,
int
cfg_max_m_blocks
)
{
bool
has_zp
=
true
;
if
(
false
)
{
}
AWQ_CALL_IF_MOE
(
vllm
::
kU4
,
16
,
4
,
256
)
AWQ_CALL_IF_MOE
(
vllm
::
kU4
,
8
,
8
,
256
)
AWQ_CALL_IF_MOE
(
vllm
::
kU4
,
8
,
4
,
128
)
AWQ_CALL_IF_MOE
(
vllm
::
kU4
,
4
,
8
,
128
)
else
{
return
false
;
}
return
true
;
}
}
// namespace marlin_moe
csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h
0 → 100644
View file @
ad385667
#pragma once
#include "marlin_moe_kernel.h"
namespace
marlin_moe
{
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool
call_marlin_moe_kernel_ku4
(
vllm
::
ScalarType
const
&
q_type
,
int
thread_n_blocks
,
int
thread_k_blocks
,
bool
has_act_order
,
int
group_blocks
,
int
num_threads
,
int
blocks
,
int
max_shared_mem
,
cudaStream_t
stream
,
const
int4
*
A_ptr
,
const
int4
*
B_ptr
,
int4
*
C_ptr
,
const
int
*
sorted_ids_ptr
,
const
float
*
topk_weights_ptr
,
const
int4
*
s_ptr
,
const
int4
*
zp_ptr
,
const
int
*
g_idx_ptr
,
int
*
expert_offsets_ptr
,
int
num_groups
,
int
expert_idx
,
int
num_experts
,
int
topk
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
tot_m
,
int
*
locks
,
bool
replicate_input
,
bool
apply_weights
,
int
m_block
,
int
max_par
,
int
cfg_max_m_blocks
);
}
// namespace marlin_moe
csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu
0 → 100644
View file @
ad385667
#include "marlin_moe_kernel_ku4b8.h"
namespace
marlin_moe
{
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool
call_marlin_moe_kernel_ku4b8
(
vllm
::
ScalarType
const
&
q_type
,
int
thread_n_blocks
,
int
thread_k_blocks
,
bool
has_act_order
,
int
group_blocks
,
int
num_threads
,
int
blocks
,
int
max_shared_mem
,
cudaStream_t
stream
,
const
int4
*
A_ptr
,
const
int4
*
B_ptr
,
int4
*
C_ptr
,
const
int
*
sorted_ids_ptr
,
const
float
*
topk_weights_ptr
,
const
int4
*
s_ptr
,
const
int4
*
zp_ptr
,
const
int
*
g_idx_ptr
,
int
*
expert_offsets_ptr
,
int
num_groups
,
int
expert_idx
,
int
num_experts
,
int
topk
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
tot_m
,
int
*
locks
,
bool
replicate_input
,
bool
apply_weights
,
int
m_block
,
int
max_par
,
int
cfg_max_m_blocks
)
{
bool
has_zp
=
false
;
if
(
false
)
{
}
GPTQ_CALL_IF_MOE
(
vllm
::
kU4B8
,
16
,
4
,
256
)
GPTQ_CALL_IF_MOE
(
vllm
::
kU4B8
,
8
,
8
,
256
)
GPTQ_CALL_IF_MOE
(
vllm
::
kU4B8
,
8
,
4
,
128
)
GPTQ_CALL_IF_MOE
(
vllm
::
kU4B8
,
4
,
8
,
128
)
else
{
return
false
;
}
return
true
;
}
}
// namespace marlin_moe
csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h
0 → 100644
View file @
ad385667
#pragma once
#include "marlin_moe_kernel.h"
namespace
marlin_moe
{
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool
call_marlin_moe_kernel_ku4b8
(
vllm
::
ScalarType
const
&
q_type
,
int
thread_n_blocks
,
int
thread_k_blocks
,
bool
has_act_order
,
int
group_blocks
,
int
num_threads
,
int
blocks
,
int
max_shared_mem
,
cudaStream_t
stream
,
const
int4
*
A_ptr
,
const
int4
*
B_ptr
,
int4
*
C_ptr
,
const
int
*
sorted_ids_ptr
,
const
float
*
topk_weights_ptr
,
const
int4
*
s_ptr
,
const
int4
*
zp_ptr
,
const
int
*
g_idx_ptr
,
int
*
expert_offsets_ptr
,
int
num_groups
,
int
expert_idx
,
int
num_experts
,
int
topk
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
tot_m
,
int
*
locks
,
bool
replicate_input
,
bool
apply_weights
,
int
m_block
,
int
max_par
,
int
cfg_max_m_blocks
);
}
// namespace marlin_moe
csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu
0 → 100644
View file @
ad385667
#include "marlin_moe_kernel_ku8b128.h"
namespace
marlin_moe
{
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool
call_marlin_moe_kernel_ku8b128
(
vllm
::
ScalarType
const
&
q_type
,
int
thread_n_blocks
,
int
thread_k_blocks
,
bool
has_act_order
,
int
group_blocks
,
int
num_threads
,
int
blocks
,
int
max_shared_mem
,
cudaStream_t
stream
,
const
int4
*
A_ptr
,
const
int4
*
B_ptr
,
int4
*
C_ptr
,
const
int
*
sorted_ids_ptr
,
const
float
*
topk_weights_ptr
,
const
int4
*
s_ptr
,
const
int4
*
zp_ptr
,
const
int
*
g_idx_ptr
,
int
*
expert_offsets_ptr
,
int
num_groups
,
int
expert_idx
,
int
num_experts
,
int
topk
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
tot_m
,
int
*
locks
,
bool
replicate_input
,
bool
apply_weights
,
int
m_block
,
int
max_par
,
int
cfg_max_m_blocks
)
{
bool
has_zp
=
false
;
if
(
false
)
{
}
GPTQ_CALL_IF_MOE
(
vllm
::
kU8B128
,
16
,
4
,
256
)
GPTQ_CALL_IF_MOE
(
vllm
::
kU8B128
,
8
,
8
,
256
)
GPTQ_CALL_IF_MOE
(
vllm
::
kU8B128
,
8
,
4
,
128
)
GPTQ_CALL_IF_MOE
(
vllm
::
kU8B128
,
4
,
8
,
128
)
else
{
return
false
;
}
return
true
;
}
}
// namespace marlin_moe
csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h
0 → 100644
View file @
ad385667
#pragma once
#include "marlin_moe_kernel.h"
namespace
marlin_moe
{
bool
call_marlin_moe_kernel_ku8b128
(
vllm
::
ScalarType
const
&
q_type
,
int
thread_n_blocks
,
int
thread_k_blocks
,
bool
has_act_order
,
int
group_blocks
,
int
num_threads
,
int
blocks
,
int
max_shared_mem
,
cudaStream_t
stream
,
const
int4
*
A_ptr
,
const
int4
*
B_ptr
,
int4
*
C_ptr
,
const
int
*
sorted_ids_ptr
,
const
float
*
topk_weights_ptr
,
const
int4
*
s_ptr
,
const
int4
*
zp_ptr
,
const
int
*
g_idx_ptr
,
int
*
expert_offsets_ptr
,
int
num_groups
,
int
expert_idx
,
int
num_experts
,
int
topk
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
tot_m
,
int
*
locks
,
bool
replicate_input
,
bool
apply_weights
,
int
m_block
,
int
max_par
,
int
cfg_max_m_blocks
);
}
csrc/moe/marlin_moe_ops.cu
0 → 100644
View file @
ad385667
/*
* Modified by Neural Magic
* Copyright (C) Marlin.2024 Elias Frantar
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>
#include "core/exception.hpp"
#include "core/scalar_type.hpp"
#include "core/registration.h"
#include "marlin_kernels/marlin_moe_kernel_ku4b8.h"
#include "marlin_kernels/marlin_moe_kernel_ku8b128.h"
#include "marlin_kernels/marlin_moe_kernel_ku4.h"
template
<
typename
T
>
inline
std
::
string
str
(
T
x
)
{
return
std
::
to_string
(
x
);
}
namespace
marlin_moe
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// For a given "a" of size [M,K] performs a permutation of the K columns based
// on the given "perm" indices.
__global__
void
permute_cols_kernel
(
int4
const
*
__restrict__
a_int4_ptr
,
int
const
*
__restrict__
perm_int_ptr
,
int4
*
__restrict__
out_int4_ptr
,
int
size_m
,
int
size_k
,
int
block_rows
)
{
int
start_row
=
block_rows
*
blockIdx
.
x
;
int
finish_row
=
start_row
+
block_rows
;
if
(
finish_row
>
size_m
)
{
finish_row
=
size_m
;
}
int
cur_block_rows
=
finish_row
-
start_row
;
int
row_stride
=
size_k
*
sizeof
(
half
)
/
16
;
auto
permute_row
=
[
&
](
int
row
)
{
int
iters
=
size_k
/
blockDim
.
x
;
int
rest
=
size_k
%
blockDim
.
x
;
int
offset
=
row
*
row_stride
;
half
const
*
a_row_half
=
reinterpret_cast
<
half
const
*>
(
a_int4_ptr
+
offset
);
half
*
out_half
=
reinterpret_cast
<
half
*>
(
out_int4_ptr
+
offset
);
int
base_k
=
0
;
for
(
int
i
=
0
;
i
<
iters
;
i
++
)
{
int
cur_k
=
base_k
+
threadIdx
.
x
;
int
src_pos
=
perm_int_ptr
[
cur_k
];
out_half
[
cur_k
]
=
a_row_half
[
src_pos
];
base_k
+=
blockDim
.
x
;
}
if
(
rest
)
{
if
(
threadIdx
.
x
<
rest
)
{
int
cur_k
=
base_k
+
threadIdx
.
x
;
int
src_pos
=
perm_int_ptr
[
cur_k
];
out_half
[
cur_k
]
=
a_row_half
[
src_pos
];
}
}
};
for
(
int
i
=
0
;
i
<
cur_block_rows
;
i
++
)
{
int
cur_row
=
start_row
+
i
;
if
(
cur_row
<
size_m
)
{
permute_row
(
cur_row
);
}
}
}
__global__
void
compute_expert_offsets
(
int
const
*
__restrict__
topk_ids
,
int
*
__restrict__
expert_offsets
,
int
topk_length
,
int
block_size
)
{
int
expert_id
=
threadIdx
.
x
;
int
num_experts
=
blockDim
.
x
;
int
occurrences
=
0
;
for
(
int
i
=
0
;
i
<
topk_length
;
++
i
)
{
occurrences
+=
(
topk_ids
[
i
]
==
expert_id
);
}
expert_offsets
[
expert_id
+
1
]
=
occurrences
;
__syncthreads
();
if
(
threadIdx
.
x
==
0
)
{
int
tot_offset
=
0
;
expert_offsets
[
0
]
=
0
;
for
(
int
i
=
0
;
i
<
num_experts
;
++
i
)
{
tot_offset
+=
ceildiv
(
expert_offsets
[
i
+
1
],
block_size
)
*
block_size
;
expert_offsets
[
i
+
1
]
=
tot_offset
;
}
}
__syncthreads
();
}
#else
__global__
void
permute_cols_kernel
(
int4
const
*
__restrict__
a_int4_ptr
,
int
const
*
__restrict__
perm_int_ptr
,
int4
*
__restrict__
out_int4_ptr
,
int
size_m
,
int
size_k
,
int
block_rows
)
{
// Marlin is not implemented yet for SM < 8.0
assert
(
false
);
return
;
}
__global__
void
compute_expert_offsets
(
int
const
*
__restrict__
topk_ids
,
int
*
__restrict__
expert_offsets
,
int
topk_length
,
int
block_size
)
{
// Marlin is not implemented yet for SM < 8.0
assert
(
false
);
return
;
}
#endif
typedef
struct
{
int
thread_k
;
int
thread_n
;
int
num_threads
;
}
thread_config_t
;
typedef
struct
{
int
max_m_blocks
;
thread_config_t
tb_cfg
;
}
exec_config_t
;
thread_config_t
small_batch_thread_configs
[]
=
{
// Ordered by priority
// thread_k, thread_n, num_threads
{
128
,
128
,
256
},
// Default
{
128
,
64
,
128
},
// Reduce N 2X, same K
{
64
,
256
,
256
},
// Reduce K 2X, increase N 2X
{
64
,
128
,
128
},
// Reduce K 2X, same N
{
64
,
64
,
128
},
// Reduce both 2X
};
thread_config_t
large_batch_thread_configs
[]
=
{
// Ordered by priority
// thread_k, thread_n, num_threads
{
64
,
256
,
256
},
// Default
{
128
,
128
,
256
},
// Reduce N 2X, increase K 2X
{
64
,
128
,
128
},
// Reduce N 2X, same K
{
128
,
64
,
128
},
// Reduce N 4X, increase K 2X
{
64
,
64
,
128
},
// Reduce N 4X, same K
};
int
get_scales_cache_size
(
thread_config_t
const
&
th_config
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
,
bool
has_act_order
,
bool
is_k_full
)
{
bool
cache_scales_chunk
=
has_act_order
&&
!
is_k_full
;
int
tb_n
=
th_config
.
thread_n
;
int
tb_k
=
th_config
.
thread_k
;
// Get max scale groups per thread-block
int
tb_groups
;
if
(
group_size
==
-
1
)
{
tb_groups
=
1
;
}
else
if
(
group_size
==
0
)
{
tb_groups
=
ceildiv
(
tb_k
,
32
);
// Worst case is 32 group size
}
else
{
tb_groups
=
ceildiv
(
tb_k
,
group_size
);
}
if
(
cache_scales_chunk
)
{
int
load_groups
=
tb_groups
*
STAGES
*
2
;
// Chunk size is 2x pipeline over dim K
load_groups
=
max
(
load_groups
,
32
);
// We load at least 32 scale groups
return
load_groups
*
tb_n
*
4
;
}
else
{
int
tb_scales
=
tb_groups
*
tb_n
*
2
;
return
tb_scales
*
STAGES
;
}
}
bool
is_valid_cache_size
(
thread_config_t
const
&
th_config
,
int
max_m_blocks
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
scales_cache_size
,
int
max_shared_mem
)
{
int
pack_factor
=
32
/
num_bits
;
// Get B size
int
tb_k
=
th_config
.
thread_k
;
int
tb_n
=
th_config
.
thread_n
;
int
b_size
=
(
tb_k
*
tb_n
/
pack_factor
)
*
4
;
// Get A size
int
m_blocks
=
ceildiv
(
prob_m
,
16
);
int
tb_max_m
=
16
;
while
(
true
)
{
if
(
m_blocks
>=
max_m_blocks
)
{
tb_max_m
*=
max_m_blocks
;
break
;
}
max_m_blocks
--
;
if
(
max_m_blocks
==
0
)
{
TORCH_CHECK
(
false
,
"Unexpected m_blocks = "
,
m_blocks
);
}
}
int
a_size
=
(
tb_max_m
*
tb_k
)
*
2
;
float
pipe_size
=
(
a_size
+
b_size
)
*
STAGES
;
TORCH_CHECK
(
max_shared_mem
/
2
>
scales_cache_size
);
// Sanity
return
pipe_size
<
0.95
f
*
(
max_shared_mem
-
scales_cache_size
);
}
bool
is_valid_config
(
thread_config_t
const
&
th_config
,
int
max_m_blocks
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
,
bool
has_act_order
,
bool
is_k_full
,
int
max_shared_mem
)
{
// Sanity
if
(
th_config
.
thread_k
==
-
1
||
th_config
.
thread_n
==
-
1
||
th_config
.
num_threads
==
-
1
)
{
return
false
;
}
// Verify K/N are divisible by thread K/N
if
(
prob_k
%
th_config
.
thread_k
!=
0
||
prob_n
%
th_config
.
thread_n
!=
0
)
{
return
false
;
}
// thread_k can be only 128 or 64 (because it must be less than groupsize
// which is 128)
if
(
th_config
.
thread_k
!=
128
&&
th_config
.
thread_k
!=
64
)
{
return
false
;
}
// Verify min for thread K/N
if
(
th_config
.
thread_n
<
min_thread_n
||
th_config
.
thread_k
<
min_thread_k
)
{
return
false
;
}
// num_threads must be at least 128 (= 4 warps)
if
(
th_config
.
num_threads
<
128
)
{
return
false
;
}
// Determine cache for scales
int
scales_cache_size
=
get_scales_cache_size
(
th_config
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
);
// Check that pipeline fits into cache
if
(
!
is_valid_cache_size
(
th_config
,
max_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
scales_cache_size
,
max_shared_mem
))
{
return
false
;
}
return
true
;
}
exec_config_t
determine_thread_config
(
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
,
bool
has_act_order
,
bool
is_k_full
,
int
max_shared_mem
)
{
int
max_m_blocks
=
4
;
while
(
max_m_blocks
>
0
)
{
if
(
prob_m
<=
16
)
{
for
(
auto
th_config
:
small_batch_thread_configs
)
{
if
(
is_valid_config
(
th_config
,
max_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
max_shared_mem
))
{
return
exec_config_t
{
max_m_blocks
,
th_config
};
}
}
}
else
{
for
(
auto
th_config
:
large_batch_thread_configs
)
{
if
(
is_valid_config
(
th_config
,
max_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
max_shared_mem
))
{
return
exec_config_t
{
max_m_blocks
,
th_config
};
}
}
}
max_m_blocks
--
;
// Process less M blocks per invocation to reduce cache
// usage
}
return
exec_config_t
{
0
,
{
-
1
,
-
1
,
-
1
}};
}
#define CALL_MOE_KERNEL_FUNCTION(KERNEL_FUNCTION) \
else if (KERNEL_FUNCTION( \
q_type, thread_n_blocks, thread_k_blocks, has_act_order, \
group_blocks, num_threads, blocks, max_shared_mem, stream, \
A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \
zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \
num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \
replicate_input, apply_weights, m_block, max_par, \
exec_cfg.max_m_blocks)) { \
}
void
marlin_mm_moe
(
const
void
*
A
,
const
void
*
B
,
void
*
C
,
const
void
*
sorted_ids
,
const
void
*
topk_weights
,
const
void
*
topk_ids
,
const
void
*
s
,
void
*
zp
,
const
void
*
g_idx
,
const
void
*
perm
,
void
*
a_tmp
,
void
*
expert_offsets
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
void
*
workspace
,
vllm
::
ScalarType
const
&
q_type
,
bool
has_act_order
,
bool
is_k_full
,
bool
has_zp
,
int
num_groups
,
int
group_size
,
int
num_experts
,
int
topk
,
int
moe_block_size
,
int
dev
,
cudaStream_t
stream
,
int
thread_k
,
int
thread_n
,
int
sms
,
int
max_par
,
bool
replicate_input
,
bool
apply_weights
)
{
TORCH_CHECK
(
prob_m
>
0
&&
prob_n
>
0
&&
prob_k
>
0
,
"Invalid MNK = ["
,
prob_m
,
", "
,
prob_n
,
", "
,
prob_k
,
"]"
);
if
(
sms
==
-
1
)
{
cudaDeviceGetAttribute
(
&
sms
,
cudaDevAttrMultiProcessorCount
,
dev
);
}
int
max_shared_mem
=
0
;
cudaDeviceGetAttribute
(
&
max_shared_mem
,
cudaDevAttrMaxSharedMemoryPerBlockOptin
,
dev
);
TORCH_CHECK
(
max_shared_mem
>
0
);
int
num_bits
=
q_type
.
size_bits
();
// Set thread config
exec_config_t
exec_cfg
;
if
(
thread_k
!=
-
1
&&
thread_n
!=
-
1
)
{
// User-defined config
exec_cfg
=
exec_config_t
{
4
,
thread_config_t
{
thread_k
,
thread_n
,
USER_THREADS
}};
}
else
{
// Auto config
exec_cfg
=
determine_thread_config
(
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
max_shared_mem
);
}
TORCH_CHECK
(
exec_cfg
.
max_m_blocks
>
0
&&
is_valid_config
(
exec_cfg
.
tb_cfg
,
exec_cfg
.
max_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
max_shared_mem
),
"Invalid thread config: max_m_blocks = "
,
exec_cfg
.
max_m_blocks
,
", thread_k = "
,
exec_cfg
.
tb_cfg
.
thread_k
,
", thread_n = "
,
exec_cfg
.
tb_cfg
.
thread_n
,
", num_threads = "
,
exec_cfg
.
tb_cfg
.
num_threads
,
" for MKN = ["
,
prob_m
,
", "
,
prob_k
,
", "
,
prob_n
,
"] and num_bits = "
,
num_bits
,
", group_size = "
,
group_size
,
", has_act_order = "
,
has_act_order
,
", is_k_full = "
,
is_k_full
,
", max_shared_mem = "
,
max_shared_mem
);
int
num_threads
=
exec_cfg
.
tb_cfg
.
num_threads
;
thread_k
=
exec_cfg
.
tb_cfg
.
thread_k
;
thread_n
=
exec_cfg
.
tb_cfg
.
thread_n
;
int
thread_k_blocks
=
thread_k
/
16
;
int
thread_n_blocks
=
thread_n
/
16
;
int
blocks
=
sms
;
TORCH_CHECK
(
prob_n
%
thread_n
==
0
,
"prob_n = "
,
prob_n
,
" is not divisible by thread_n = "
,
thread_n
);
TORCH_CHECK
(
prob_k
%
thread_k
==
0
,
"prob_k = "
,
prob_k
,
" is not divisible by thread_k = "
,
thread_k
);
int
group_blocks
=
0
;
if
(
has_act_order
)
{
if
(
is_k_full
)
{
TORCH_CHECK
(
group_size
!=
-
1
);
group_blocks
=
group_size
/
16
;
TORCH_CHECK
(
prob_k
%
group_blocks
==
0
,
"prob_k = "
,
prob_k
,
" is not divisible by group_blocks = "
,
group_blocks
);
}
else
{
TORCH_CHECK
(
group_size
==
0
);
group_blocks
=
0
;
}
}
else
{
if
(
group_size
==
-
1
)
{
group_blocks
=
-
1
;
}
else
{
group_blocks
=
group_size
/
16
;
TORCH_CHECK
(
prob_k
%
group_blocks
==
0
,
"prob_k = "
,
prob_k
,
" is not divisible by group_blocks = "
,
group_blocks
);
}
}
int
tot_m
=
prob_m
;
const
int
*
topk_ids_ptr
=
(
const
int
*
)
topk_ids
;
int
*
expert_offsets_ptr
=
(
int
*
)
expert_offsets
;
compute_expert_offsets
<<<
1
,
num_experts
,
0
,
stream
>>>
(
topk_ids_ptr
,
expert_offsets_ptr
,
tot_m
*
topk
,
moe_block_size
);
bool
do_permute_a
=
has_act_order
;
// If we have a full K, then we can run the non-act-order version of Marlin
// (since the weight rows are reordered by increasing group ids, and by
// having a full K, we have full original groups)
if
(
is_k_full
)
{
has_act_order
=
false
;
}
int
pack_factor
=
32
/
q_type
.
size_bits
();
for
(
int
expert_idx
=
0
;
expert_idx
<
num_experts
;
++
expert_idx
)
{
const
int4
*
A_ptr
=
(
const
int4
*
)
A
;
int4
*
a_tmp_ptr
=
(
int4
*
)
a_tmp
;
const
int4
*
B_ptr
=
(
const
int4
*
)
B
+
(
prob_n
*
prob_k
/
(
pack_factor
*
4
))
*
expert_idx
;
int4
*
C_ptr
=
(
int4
*
)
C
;
const
float
*
topk_weights_ptr
=
(
const
float
*
)
topk_weights
;
const
int
*
sorted_ids_ptr
=
(
const
int
*
)
sorted_ids
;
const
int4
*
s_ptr
=
(
const
int4
*
)
s
+
num_groups
*
prob_n
/
8
*
expert_idx
;
const
int4
*
zp_ptr
=
(
const
int4
*
)
zp
+
num_groups
*
prob_n
/
(
pack_factor
*
4
)
*
expert_idx
;
const
int
*
g_idx_ptr
=
(
const
int
*
)
g_idx
+
prob_k
*
expert_idx
;
const
int
*
perm_ptr
=
(
const
int
*
)
perm
+
prob_k
*
expert_idx
;
int
*
locks
=
(
int
*
)
workspace
;
if
(
do_permute_a
)
{
// Permute A columns
int
topk_rows
=
replicate_input
?
tot_m
:
tot_m
*
topk
;
int
block_rows
=
ceildiv
(
topk_rows
,
blocks
);
permute_cols_kernel
<<<
blocks
,
num_threads
,
0
,
stream
>>>
(
A_ptr
,
perm_ptr
,
a_tmp_ptr
,
topk_rows
,
prob_k
,
block_rows
);
A_ptr
=
a_tmp_ptr
;
}
int
tot_m_blocks
=
ceildiv
(
tot_m
,
16
);
for
(
int
m_block
=
0
;
m_block
<
tot_m_blocks
;
m_block
+=
4
*
exec_cfg
.
max_m_blocks
)
{
if
(
false
)
{
}
CALL_MOE_KERNEL_FUNCTION
(
call_marlin_moe_kernel_ku4b8
)
CALL_MOE_KERNEL_FUNCTION
(
call_marlin_moe_kernel_ku8b128
)
CALL_MOE_KERNEL_FUNCTION
(
call_marlin_moe_kernel_ku4
)
else
{
TORCH_CHECK
(
false
,
"Unsupported shapes: MNK = ["
+
str
(
prob_m
)
+
", "
+
str
(
prob_n
)
+
", "
+
str
(
prob_k
)
+
"]"
+
", has_act_order = "
+
str
(
has_act_order
)
+
", num_groups = "
+
str
(
num_groups
)
+
", group_size = "
+
str
(
group_size
)
+
", thread_n_blocks = "
+
str
(
thread_n_blocks
)
+
", thread_k_blocks = "
+
str
(
thread_k_blocks
));
}
}
}
}
}
// namespace marlin_moe
torch
::
Tensor
marlin_gemm_moe
(
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
b_q_weights
,
const
torch
::
Tensor
&
sorted_ids
,
const
torch
::
Tensor
&
topk_weights
,
const
torch
::
Tensor
&
topk_ids
,
const
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
b_zeros
,
const
torch
::
Tensor
&
g_idx
,
const
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
vllm
::
ScalarTypeTorchPtr
const
&
b_q_type
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
,
int64_t
num_experts
,
int64_t
topk
,
int64_t
moe_block_size
,
bool
replicate_input
,
bool
apply_weights
)
{
bool
has_zp
=
b_zeros
.
size
(
1
)
!=
0
;
if
(
has_zp
)
{
TORCH_CHECK
(
*
b_q_type
==
vllm
::
kU4
,
"b_q_type must be u4 when has_zp = True. Got = "
,
b_q_type
->
str
());
}
else
{
TORCH_CHECK
(
*
b_q_type
==
vllm
::
kU4B8
||
*
b_q_type
==
vllm
::
kU8B128
,
"b_q_type must be uint4b8 or uint8b128. Got = "
,
b_q_type
->
str
());
}
int
pack_factor
=
32
/
b_q_type
->
size_bits
();
int
max_par
=
4
;
int
dev
=
a
.
get_device
();
auto
options_dtype
=
torch
::
TensorOptions
().
dtype
(
a
.
dtype
()).
device
(
a
.
device
());
auto
options_int
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt
).
device
(
a
.
device
());
torch
::
Tensor
c
=
torch
::
zeros
({
size_m
,
topk
,
size_n
},
options_dtype
);
torch
::
Tensor
a_tmp
=
replicate_input
?
torch
::
zeros
({
size_m
,
size_k
},
options_dtype
)
:
torch
::
zeros
({
size_m
,
topk
,
size_k
},
options_dtype
);
torch
::
Tensor
expert_offsets
=
torch
::
empty
({
num_experts
+
1
},
options_int
);
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int
thread_k
=
-
1
;
// thread_n: `n` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int
thread_n
=
-
1
;
// sms: number of SMs to use for the kernel (can usually be left as auto -1)
int
sms
=
-
1
;
// Detect groupsize and act_order
int
num_groups
=
-
1
;
int
group_size
=
-
1
;
bool
has_act_order
=
g_idx
.
size
(
1
)
!=
0
;
int
b_rank
=
b_scales
.
sizes
().
size
();
TORCH_CHECK
(
b_rank
==
3
,
"b_scales rank = "
,
b_rank
,
" is not 3"
);
TORCH_CHECK
(
b_scales
.
size
(
2
)
==
size_n
,
"b_scales dim 2 = "
,
b_scales
.
size
(
2
),
" is not size_n = "
,
size_n
);
num_groups
=
b_scales
.
size
(
1
);
TORCH_CHECK
(
VLLM_IMPLIES
(
!
is_k_full
,
has_act_order
),
"if is_k_full is false, has_act_order must be true"
);
if
(
has_act_order
)
{
if
(
is_k_full
)
{
TORCH_CHECK
(
num_groups
>
1
,
"For act_order, num_groups must be > 1"
);
TORCH_CHECK
(
size_k
%
num_groups
==
0
,
"size_k = "
,
size_k
,
", is not divisible by num_groups = "
,
num_groups
);
group_size
=
size_k
/
num_groups
;
}
else
{
group_size
=
0
;
}
}
else
{
if
(
num_groups
>
1
)
{
TORCH_CHECK
(
size_k
%
num_groups
==
0
,
"size_k = "
,
size_k
,
", is not divisible by b_scales.size(0) = "
,
b_scales
.
size
(
0
));
group_size
=
size_k
/
num_groups
;
}
else
{
group_size
=
-
1
;
}
}
// Verify b_zeros
if
(
has_zp
)
{
int
rank
=
b_zeros
.
sizes
().
size
();
TORCH_CHECK
(
rank
==
3
,
"b_zeros rank = "
,
rank
,
" is not 3"
);
TORCH_CHECK
(
b_zeros
.
size
(
1
)
==
num_groups
,
"b_zeros dim 1 = "
,
b_zeros
.
size
(
1
),
" is not num_groups = "
,
num_groups
);
TORCH_CHECK
(
b_zeros
.
size
(
2
)
==
size_n
/
pack_factor
,
"b_zeros dim 2 = "
,
b_zeros
.
size
(
2
),
" is not size_n / pack_factor = "
,
size_n
/
pack_factor
);
}
marlin_moe
::
marlin_mm_moe
(
a
.
data_ptr
(),
b_q_weights
.
data_ptr
(),
c
.
data_ptr
(),
sorted_ids
.
data_ptr
(),
topk_weights
.
data_ptr
(),
topk_ids
.
data_ptr
(),
b_scales
.
data_ptr
(),
b_zeros
.
data_ptr
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
(),
expert_offsets
.
data_ptr
(),
size_m
,
size_n
,
size_k
,
workspace
.
data_ptr
(),
*
b_q_type
,
has_act_order
,
is_k_full
,
has_zp
,
num_groups
,
group_size
,
num_experts
,
topk
,
moe_block_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
max_par
,
replicate_input
,
apply_weights
);
return
c
;
}
TORCH_LIBRARY_IMPL_EXPAND
(
TORCH_EXTENSION_NAME
,
CUDA
,
m
)
{
m
.
impl
(
"marlin_gemm_moe"
,
&
marlin_gemm_moe
);
}
csrc/moe/torch_bindings.cpp
View file @
ad385667
...
@@ -7,6 +7,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
...
@@ -7,6 +7,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
"token_expert_indices, Tensor gating_output) -> ()"
);
"token_expert_indices, Tensor gating_output) -> ()"
);
m
.
impl
(
"topk_softmax"
,
torch
::
kCUDA
,
&
topk_softmax
);
m
.
impl
(
"topk_softmax"
,
torch
::
kCUDA
,
&
topk_softmax
);
#ifndef USE_ROCM
m
.
def
(
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
"b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, "
"__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, "
"int size_n, int size_k, bool is_k_full, int num_experts, int topk, "
"int moe_block_size, bool replicate_input, bool apply_weights)"
" -> Tensor"
);
// conditionally compiled so impl registration is in source file
#endif
}
}
REGISTER_EXTENSION
(
TORCH_EXTENSION_NAME
)
REGISTER_EXTENSION
(
TORCH_EXTENSION_NAME
)
Prev
1
…
3
4
5
6
7
8
9
10
11
…
19
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment