Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
0e1300f7
Commit
0e1300f7
authored
Jan 26, 2026
by
zhanghj2
Browse files
适配v32的decode kernel
parent
7abe5160
Changes
4
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
858 additions
and
68 deletions
+858
-68
csrc/defines.h
csrc/defines.h
+1
-1
csrc/sm90/decode/sparse_fp8/config.h
csrc/sm90/decode/sparse_fp8/config.h
+86
-65
csrc/sm90/decode/sparse_fp8/splitkv_mla.cuh
csrc/sm90/decode/sparse_fp8/splitkv_mla.cuh
+500
-0
csrc/utils.h
csrc/utils.h
+271
-2
No files found.
csrc/defines.h
View file @
0e1300f7
...
...
@@ -4,7 +4,7 @@
// #include <cutlass/arch/barrier.h>
using
bf16
=
cutlass
::
bfloat16_t
;
using
fp8
=
cutlass
::
float_e4m3_t
;
using
fp8
=
unsigned
char
;
// using transac_bar_t = cutlass::arch::ClusterTransactionBarrier;
// using cutlass::arch::fence_view_async_shared;
// using cutlass::arch::fence_barrier_init;
...
...
csrc/sm90/decode/sparse_fp8/config.h
View file @
0e1300f7
...
...
@@ -16,10 +16,11 @@ template<ModelType MODEL_TYPE, int NUM_HEADS>
class
KernelTemplate
{
public:
static_assert
(
NUM_HEADS
==
64
||
NUM_HEADS
==
128
);
static
constexpr
int
NUM_M_BLOCKS
=
NUM_HEADS
/
64
;
static
constexpr
int
CLUSTER_SIZE
=
NUM_M_BLOCKS
;
static_assert
(
NUM_HEADS
==
64
||
NUM_HEADS
==
128
||
NUM_HEADS
==
16
);
// todo only support tp8
static
constexpr
int
BLOCK_M
=
16
;
static
constexpr
int
NUM_M_BLOCKS
=
NUM_HEADS
/
BLOCK_M
;
static
constexpr
bool
Is_causal
=
false
;
static
constexpr
int
HEAD_DIM_K
=
MODEL_TYPE
==
ModelType
::
V32
?
576
:
512
;
static
constexpr
int
HEAD_DIM_V
=
512
;
static
constexpr
int
HEAD_DIM_ROPE
=
64
;
...
...
@@ -28,67 +29,88 @@ static constexpr int HEAD_DIM_NOPE = HEAD_DIM_K - HEAD_DIM_ROPE;
static
constexpr
int
QUANT_TILE_SIZE
=
MODEL_TYPE
==
ModelType
::
V32
?
128
:
64
;
static
constexpr
int
NUM_SCALES
=
MODEL_TYPE
==
ModelType
::
V32
?
4
:
8
;
// For MODEL1: 7 fp8_e4m3 + 1 padding
static
constexpr
int
NUM_THREADS
=
128
*
3
;
static
constexpr
int
BLOCK_M
=
64
;
static
constexpr
int
NUM_THREADS
=
256
;
static
constexpr
int
TOPK_BLOCK_SIZE
=
64
;
static
constexpr
int
NUM_K_BUFS
=
2
;
using
SmemLayoutQTile
=
decltype
(
tile_to_shape
(
GMMA
::
Layout_SW128_Atom
<
bf16
,
GMMA
::
Major
::
K
>
{},
Shape
<
Int
<
BLOCK_M
>
,
Int
<
64
>>
{}
));
template
<
int
NUM_TILES
>
using
SmemLayoutQTiles
=
decltype
(
tile_to_shape
(
SmemLayoutQTile
{},
Shape
<
Int
<
BLOCK_M
>
,
Int
<
64
*
NUM_TILES
>>
{},
Step
<
_1
,
_2
>
{}
));
using
SmemLayoutQ
=
SmemLayoutQTiles
<
HEAD_DIM_K
/
64
>
;
using
SmemLayoutKTile
=
decltype
(
tile_to_shape
(
GMMA
::
Layout_INTER_Atom
<
bf16
,
GMMA
::
Major
::
K
>
{},
Shape
<
Int
<
TOPK_BLOCK_SIZE
>
,
_64
>
{},
Step
<
_1
,
_2
>
{}
));
template
<
int
NUM_TILES
>
using
SmemLayoutKTiles
=
decltype
(
tile_to_shape
(
SmemLayoutKTile
{},
Shape
<
Int
<
TOPK_BLOCK_SIZE
>
,
Int
<
64
*
NUM_TILES
>>
{},
Step
<
_1
,
_2
>
{}
));
template
<
int
NUM_TILES
>
using
SmemLayoutKTilesTransposed
=
decltype
(
composition
(
SmemLayoutKTiles
<
NUM_TILES
>
{},
Layout
<
Shape
<
Int
<
64
*
NUM_TILES
>
,
Int
<
TOPK_BLOCK_SIZE
>>
,
Stride
<
Int
<
TOPK_BLOCK_SIZE
>
,
_1
>>
{}
));
static
constexpr
int
OBUF_SW
=
64
;
using
SmemLayoutOBufAtom
=
GMMA
::
Layout_K_SW128_Atom
<
bf16
>
;
using
SmemLayoutOBuf
=
decltype
(
tile_to_shape
(
SmemLayoutOBufAtom
{},
Shape
<
Int
<
BLOCK_M
>
,
Int
<
HEAD_DIM_V
>>
{},
Step
<
_1
,
_2
>
{}
));
using
SmemLayoutOAccumBuf
=
Layout
<
Shape
<
Int
<
BLOCK_M
>
,
Int
<
HEAD_DIM_V
>>
,
Stride
<
Int
<
520
>
,
_1
>
// We use stride = 520 here to avoid bank conflict
using
elem_type
=
cutlass
::
bfloat16_t
;
using
MMA_Atom_Arch
=
std
::
conditional_t
<
std
::
is_same_v
<
elem_type
,
cutlass
::
half_t
>
,
MMA_Atom
<
GFX928_16x16x64_F32F16F16F32_NT
>
,
MMA_Atom
<
GFX928_16x16x64_F32BF16BF16F32_NT
>
>
;
using
SmemLayoutK
=
SmemLayoutKTiles
<
HEAD_DIM_K
/
64
>
;
using
SmemLayoutV
=
SmemLayoutKTilesTransposed
<
HEAD_DIM_V
/
64
>
;
using
SmemLayoutHalfV
=
SmemLayoutKTilesTransposed
<
HEAD_DIM_V
/
64
/
2
>
;
using
SmemLayoutS
=
decltype
(
tile_to_shape
(
GMMA
::
Layout_K_SW128_Atom
<
bf16
>
{},
Shape
<
Int
<
BLOCK_M
>
,
Int
<
TOPK_BLOCK_SIZE
>>
{}
));
static
constexpr
int
kNWarps
=
4
;
using
ValLayoutMNK
=
Layout
<
Shape
<
_1
,
_1
,
_1
>>
;
using
TiledMma
=
TiledMMA
<
MMA_Atom_Arch
,
Layout
<
Shape
<
_1
,
Int
<
kNWarps
>
,
_1
>>
,
// 1x4x1 or 1x8x1 thread group
ValLayoutMNK
>
;
using
MMA_Atom_Arch_16_16_32
=
std
::
conditional_t
<
std
::
is_same_v
<
elem_type
,
cutlass
::
half_t
>
,
MMA_Atom
<
GFX928_16x16x32_F32F16F16F32_NN
>
,
MMA_Atom
<
GFX928_16x16x32_F32BF16BF16F32_NN
>
>
;
using
TiledMma_16_16_32
=
TiledMMA
<
MMA_Atom_Arch_16_16_32
,
Layout
<
Shape
<
_1
,
Int
<
kNWarps
>
,
_1
>>
,
// 1x4x1 or 1x8x1 thread group
ValLayoutMNK
>
;
using
MMA_Atom_Arch_16x32_NT
=
std
::
conditional_t
<
std
::
is_same_v
<
elem_type
,
cutlass
::
half_t
>
,
MMA_Atom
<
GFX928_16x32x16_F32F16F16F32_NT
>
,
MMA_Atom
<
GFX928_16x32x16_F32BF16BF16F32_NT
>
>
;
using
TiledMma_O
=
TiledMMA
<
MMA_Atom_Arch_16x32_NT
,
Layout
<
Shape
<
_1
,
Int
<
kNWarps
>
,
_1
>>
,
// 1x4x1 or 1x8x1 thread group
ValLayoutMNK
>
;
using
SmemLayoutAtomK
=
decltype
(
composition
(
Swizzle
<
3
,
3
,
3
>
{},
Layout
<
Shape
<
Int
<
8
>
,
Int
<
32
>>
,
Stride
<
Int
<
32
>
,
_1
>>
{}));
using
SmemLayoutK
=
decltype
(
tile_to_shape
(
SmemLayoutAtomK
{},
Shape
<
Int
<
TOPK_BLOCK_SIZE
>
,
Int
<
8
*
32
>>
{}));
using
SmemLayoutAtomV
=
SmemLayoutAtomK
;
using
SmemLayoutV
=
decltype
(
tile_to_shape
(
SmemLayoutAtomV
{},
Shape
<
Int
<
TOPK_BLOCK_SIZE
>
,
Int
<
512
>>
{}));
using
SmemLayoutVtransposed
=
decltype
(
composition
(
SmemLayoutV
{},
make_layout
(
Shape
<
Int
<
512
>
,
Int
<
TOPK_BLOCK_SIZE
>>
{},
GenRowMajor
{})));
using
SmemLayoutVtransposedNoSwizzle
=
decltype
(
get_nonswizzle_portion
(
SmemLayoutVtransposed
{}));
using
SmemLayoutAtomP
=
Layout
<
Shape
<
Int
<
4
*
16
*
16
>>
,
Stride
<
Int
<
1
>>>
;
using
SmemLayoutP
=
decltype
(
tile_to_shape
(
SmemLayoutAtomP
{},
Shape
<
Int
<
4
*
16
*
16
>>
{}));
using
SmemLayoutRow
=
Layout
<
Shape
<
_128
>
,
Stride
<
_1
>>
;
using
Element
=
cutlass
::
bfloat16_t
;
using
ElementAccum
=
float
;
struct
SharedMemoryPlan
{
union
{
struct
{
cute
::
array_aligned
<
Element
,
cute
::
cosize_v
<
SmemLayoutV
>>
smem_v
;
};
struct
{
// cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutV_tmp>> smem_v_tmp; // Double buffer
cute
::
array_aligned
<
Element
,
cute
::
cosize_v
<
SmemLayoutP
>>
smem_p
;
cute
::
array_aligned
<
ElementAccum
,
cute
::
cosize_v
<
SmemLayoutRow
>>
smem_row_sum
;
cute
::
array_aligned
<
ElementAccum
,
cute
::
cosize_v
<
SmemLayoutRow
>>
smem_row_max
;
};
// struct {
// cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutO>> smem_o;
// // cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutP>> smem_p;
// // cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_row_sum;
// // cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_row_max;
// };
// struct {
// cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutQ>> smem_q;
// };
};
// array_aligned<bf16, cosize_v<SmemLayoutQ>> q;
// union {
// array_aligned<bf16, cosize_v<SmemLayoutK>> k[NUM_K_BUFS];
...
...
@@ -131,9 +153,8 @@ struct SharedMemoryPlan {
static
__device__
__forceinline__
void
compute_attn_1rowblock_splitkv_sparse_mla_fp8
(
const
SparseAttnDecodeParams
&
params
,
const
DecodingSchedMeta
&
sched_meta
,
int
batch_idx
);
static
__device__
__forceinline__
void
devfunc
(
const
SparseAttnDecodeParams
&
params
);
...
...
csrc/sm90/decode/sparse_fp8/splitkv_mla.cuh
View file @
0e1300f7
This diff is collapsed.
Click to expand it.
csrc/utils.h
View file @
0e1300f7
#pragma once
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <cstdint>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
#include <cute/tensor.hpp>
#include "defines.h"
#define CHECK_CUDA(call) \
do { \
cudaError_t status_ = call; \
...
...
@@ -80,3 +87,265 @@ struct RingBufferState {
return
new_state
;
}
};
namespace
flash
{
using
namespace
cute
;
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Binary "maximum" functor for use with Allreduce.
///
/// Returns `x` when `x > y` holds and `y` otherwise, so when neither value
/// compares greater (for example, equal inputs) the second argument is returned.
template <typename T>
struct MaxOp {
    __device__ __forceinline__ T operator()(T const& x, T const& y) {
        if (x > y) {
            return x;
        }
        return y;
    }
};
/// `float` specialization of MaxOp that forwards to the `max` function
/// instead of using a comparison plus select.
template <>
struct MaxOp<float> {
    // The original author notes that this form is slightly faster.
    __device__ __forceinline__ float operator()(float const& x, float const& y) {
        const float larger = max(x, y);
        return larger;
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Binary "sum" functor for use with Allreduce: returns `x + y` converted to `T`.
template <typename T>
struct SumOp {
    __device__ __forceinline__ T operator()(T const& x, T const& y) {
        return x + y;
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Recursive XOR-shuffle reduction step.
///
/// `run` combines the caller's value with the value held by the lane whose
/// index differs by `THREADS / 2` (exchanged through `__shfl_xor` with a width
/// argument of 64), then continues with `Allreduce<THREADS / 2>`.
///
/// The recursion is terminated by the explicit specializations defined in this
/// file. Note that `Allreduce<32>` is one of them, so instantiations that
/// recurse through 32 do not go on to offsets 8, 4, 2 and 1.
///
/// @tparam THREADS  One of 2, 4, 8, 16, 32 or 64.
template <int THREADS>
struct Allreduce {
    // Same accepted set as before: the powers of two from 2 to 64.
    static_assert(THREADS >= 2 && THREADS <= 64 && (THREADS & (THREADS - 1)) == 0);

    /// @param x   This lane's contribution.
    /// @param op  Binary functor used to combine two values (e.g. MaxOp, SumOp).
    /// @return    The value produced after this step and all recursive steps.
    template <typename T, typename Operator>
    static __device__ __forceinline__ T run(T x, Operator& op) {
        constexpr int kHalf = THREADS / 2;
        // Fetch the partner lane's value, then fold it into ours.
        x = op(x, __shfl_xor(x, kHalf, 64));
        return Allreduce<kHalf>::run(x, op);
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Recursion terminator: with a single participant there is nothing left to
/// exchange, so the input value is returned unchanged and `op` is not invoked.
template <>
struct Allreduce<1> {
    template <typename T, typename Operator>
    static __device__ __forceinline__ T run(T x, Operator& op) {
        return x;
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Specialization that performs exactly one exchange (lane offset 16, width
/// argument 64) and then stops; it does NOT continue to offsets 8, 4, 2 and 1.
///
/// As a consequence, `Allreduce<64>` (which recurses into this specialization)
/// only combines lanes whose indices differ by 32 and/or 16.
///
/// NOTE(review): this looks intentional for a layout in which the values of one
/// row live in lanes that are 16 apart — confirm against the callers before
/// using this as a general 32-lane all-reduce.
template <>
struct Allreduce<32> {
    template <typename T, typename Operator>
    static __device__ __forceinline__ T run(T x, Operator& op) {
        constexpr int kLaneOffset = 16;
        x = op(x, __shfl_xor(x, kLaneOffset, 64));
        return x;
    }
};
/// Predicated tile copy from `S` to `D`, one (m, k) slice at a time.
///
/// Both tensors must be rank 3 with identical mode sizes (the original code
/// labels the modes MMA, MMA_M and MMA_K).
///
/// For every index `m` of mode 1:
///   - the row is processed when `Is_even_MN` is true or when
///     `get<0>(identity_MN(0, m, 0)) < max_MN`;
///   - otherwise the whole row `D(_, m, _)` is cleared if `Clear_OOB_MN` is set.
/// For every index `k` of mode 2 of a processed row:
///   - the slice is copied when `Is_even_K` is true or `predicate_K(k)` is true;
///   - otherwise `D(_, m, k)` is cleared if `Clear_OOB_K` is set.
///
/// @param tiled_copy   Copy object forwarded to `cute::copy`.
/// @param S            Source tensor.
/// @param D            Destination tensor.
/// @param identity_MN  Coordinate tensor used for the row bound check.
/// @param predicate_K  Per-`k` predicate.
/// @param max_MN       Exclusive upper bound for the row coordinate.
/// @param begin_k      Not used by this implementation.
template <bool Is_even_MN = true, bool Is_even_K = true, bool Clear_OOB_MN = false, bool Clear_OOB_K = true,
          typename TiledCopy,
          typename Engine0, typename Layout0, typename Engine1, typename Layout1,
          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
__forceinline__ __device__ void copy(TiledCopy tiled_copy,
                                     Tensor<Engine0, Layout0> const &S,
                                     Tensor<Engine1, Layout1> &D,
                                     Tensor<Engine2, Layout2> const &identity_MN,
                                     Tensor<Engine3, Layout3> const &predicate_K,
                                     const int max_MN = 0,
                                     int begin_k = 0) {
    // Shape contract between source and destination.
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));  // MMA
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));  // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));  // MMA_K
    // Clearing out-of-bound rows without clearing out-of-bound columns is not supported.
    static_assert(!(Clear_OOB_MN && !Clear_OOB_K));

    #pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
        const bool row_in_bounds = Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN;
        if (!row_in_bounds) {
            // Out-of-bound row: optionally zero the whole destination row.
            if (Clear_OOB_MN) {
                cute::clear(D(_, m, _));
            }
            continue;
        }

        #pragma unroll
        for (int k = 0; k < size<2>(S); ++k) {
            const bool col_in_bounds = Is_even_K || predicate_K(k);
            if (col_in_bounds) {
                cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
            } else if (Clear_OOB_K) {
                cute::clear(D(_, m, k));
            }
        }
    }
}
/// Reads one tile with `__builtin_amdgcn_ds_read_m32x16f16` and stores the
/// first eight 16-bit words of the result into `dst`.
///
/// The source address is the base pointer of `src` (cast to an
/// `address_space(3)` `__fp16` pointer) plus a compile-time offset of
/// `src.layout()(0, row, col) * 2`; the factor 2 is presumably the byte size of
/// a 16-bit element — TODO confirm against the builtin's documentation.
/// The words are written starting at the address of `dst(0, r_row, col)`, which
/// assumes that eight 16-bit elements are contiguous there — TODO confirm.
///
/// @tparam row    Row coordinate used to address `src`.
/// @tparam col    Column coordinate used for both `src` and `dst`.
/// @tparam r_row  Row coordinate used to address `dst`.
template <int row, int col, int r_row, typename Tensor0, typename Tensor1>
__forceinline__ __device__ void __ds_read_m32x16_row_col_rrow(Tensor0 &src, Tensor1 &dst) {
    auto src_base = reinterpret_cast<__fp16 *>(src.data().get());
    auto src_layout = src.layout();
    constexpr short src_offset = src_layout(0, row, col) * 2;

    auto loaded = __builtin_amdgcn_ds_read_m32x16f16(
        (__attribute__((address_space(3))) __fp16 *)(src_base), src_offset);

    // Move the result as raw 16-bit words so the element type is irrelevant.
    uint16_t *loaded_words = reinterpret_cast<uint16_t *>(&loaded);
    uint16_t *dst_words = reinterpret_cast<uint16_t *>(&(dst(0, r_row, col)));
    #pragma unroll
    for (int word = 0; word < 8; ++word) {
        dst_words[word] = loaded_words[word];
    }
}
/// Reads one tile with `__builtin_amdgcn_ds_read_m32x16f16` and stores the
/// first eight 16-bit words of the result into `dst` at the same (row, col)
/// coordinate that was used to address `src`.
///
/// The source address is the base pointer of `src` (cast to an
/// `address_space(3)` `__fp16` pointer) plus a compile-time offset of
/// `src.layout()(0, row, col) * 2`; the factor 2 is presumably the byte size of
/// a 16-bit element — TODO confirm against the builtin's documentation.
/// The words are written starting at the address of `dst(0, row, col)`, which
/// assumes that eight 16-bit elements are contiguous there — TODO confirm.
///
/// @tparam row  Row coordinate used for both `src` and `dst`.
/// @tparam col  Column coordinate used for both `src` and `dst`.
template <int row, int col, typename Tensor0, typename Tensor1>
__forceinline__ __device__ void __ds_read_m32x16_row_col(Tensor0 &src, Tensor1 &dst) {
    auto src_base = reinterpret_cast<__fp16 *>(src.data().get());
    auto src_layout = src.layout();
    constexpr short src_offset = src_layout(0, row, col) * 2;

    auto loaded = __builtin_amdgcn_ds_read_m32x16f16(
        (__attribute__((address_space(3))) __fp16 *)(src_base), src_offset);

    // Move the result as raw 16-bit words so the element type is irrelevant.
    uint16_t *loaded_words = reinterpret_cast<uint16_t *>(&loaded);
    uint16_t *dst_words = reinterpret_cast<uint16_t *>(&(dst(0, row, col)));
    #pragma unroll
    for (int word = 0; word < 8; ++word) {
        dst_words[word] = loaded_words[word];
    }
}
/// Converts one 8-bit E4M3 value (1 sign, 4 exponent, 3 mantissa bits,
/// exponent bias 7) to `float` using integer bit manipulation.
///
/// Steps:
///   1. Move the byte to the top of a 32-bit word and split sign / magnitude.
///   2. For subnormal inputs (exponent field 0) the count of leading zeros
///      gives the shift that renormalizes the mantissa; for normal inputs the
///      shift is 0.
///   3. Shift the exponent/mantissa into the binary32 positions and re-bias
///      the exponent by 0x78 (= 127 - 7), compensating for the renormalization.
///
/// Zero is handled explicitly: for a zero magnitude `__clz` returns 32, which
/// would otherwise make the re-bias term produce +/-2^-35 instead of +/-0.
///
/// NOTE(review): the all-ones pattern (0x7F / 0xFF) is decoded arithmetically
/// as +/-480. If the inputs follow the OCP E4M3 encoding, that pattern means
/// NaN — confirm whether such inputs can occur and need special handling.
///
/// @param input  Raw E4M3 encoded byte.
/// @return       The decoded value as `float`.
inline __device__ float fp8e4m3_to_fp32(const fp8 &input) {
    const uint32_t w = (uint32_t)input << 24;
    const uint32_t sign = w & UINT32_C(0x80000000);
    const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF);

    // Normal inputs have their leading one within the top 5 bits (clz <= 4);
    // anything beyond that is the renormalization distance of a subnormal.
    uint32_t renorm_shift = __clz(nonsign);
    renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;

    // A zero magnitude must stay zero; skip the exponent re-bias in that case.
    const uint32_t magnitude =
        nonsign == 0 ? UINT32_C(0)
                     : ((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23));
    uint32_t result = sign | magnitude;

    union {
        uint32_t as_bits;
        float as_value;
    } fp32 = {result};
    return fp32.as_value;
}
/// Re-groups a rank-3 accumulator layout into a rank-2 (row, column) layout.
///
/// Mode 0 of `acc_layout` is first split with `logical_divide(..., Shape<_1>)`;
/// the result is then assembled as
///   row mode    = mode 1 of the divided layout,
///   column mode = (second part of the divided mode 0, mode 2).
///
/// Example from the original author:
///   (_4,_1,_2):(_1,_0,_4) -> ((_1,_4),_1,_2):((_0,_1),_0,_4)
///   and the returned layout is (1, (4, 2)):((_0),(_1,_4)).
///
/// @param acc_layout  Accumulator layout; must have rank 3.
/// @return            The regrouped rank-2 layout.
template <typename Layout>
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
    // static_assert(decltype(size<0>(acc_layout))::value == 4 || decltype(size<0>(acc_layout))::value == 8);
    static_assert(decltype(rank(acc_layout))::value == 3);

    auto divided = logical_divide(acc_layout, Shape<_1>{});
    auto row_mode = make_layout(get<1>(divided));
    auto col_mode = make_layout(get<1>(get<0>(divided)), get<2>(divided));
    return make_layout(row_mode, col_mode);
}
/// Converts every element of `tensor` to `To_type` and returns the result as a
/// new tensor with the same layout.
///
/// When `To_type` equals the tensor's element type the input is returned
/// as-is. That early return now sits in an `if constexpr / else` pair so the
/// conversion path is not instantiated in that case; previously both `return`
/// statements took part in return-type deduction, which fails to compile
/// whenever the two tensor types differ.
///
/// Rounding mode passed to `cutlass::NumericArrayConverter`:
///   - built for `__gfx938__`: `round_to_nearest` is requested explicitly for
///     `cutlass::bfloat16_t` and `cutlass::float_e4m3_t` targets;
///   - all other builds: `round_toward_zero` is requested explicitly for
///     `cutlass::bfloat16_t` targets;
///   - every other target type uses the converter's default rounding mode.
///
/// HACK (kept from the original): the source is reinterpreted as one
/// `cutlass::Array<From_type, numel>`, which requires the tensor's elements to
/// be contiguous at `tensor.data()`.
///
/// @tparam To_type  Destination element type.
/// @param tensor    Source tensor with a statically known size.
/// @return          Tensor of `To_type` with `layout(tensor)`, or `tensor`
///                  itself when no conversion is needed.
template <typename To_type, typename Engine, typename Layout>
__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
    using From_type = typename Engine::value_type;
    if constexpr (std::is_same_v<To_type, From_type>) {
        return tensor;
    } else {
        constexpr int numel = decltype(size(tensor))::value;
        Tensor tensor_To_type = make_tensor<To_type>(layout(tensor));
        cutlass::Array<To_type, numel> *result_ptr =
            reinterpret_cast<cutlass::Array<To_type, numel> *>(tensor_To_type.data());

#if defined(__gfx938__)
        // gfx938: bf16 and e4m3 targets both request round-to-nearest explicitly.
        constexpr bool kExplicitRounding = std::is_same_v<To_type, cutlass::bfloat16_t> ||
                                           std::is_same_v<To_type, cutlass::float_e4m3_t>;
        constexpr cutlass::FloatRoundStyle kRoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
#else
        // Other targets: only bf16 gets an explicit mode, and it truncates.
        constexpr bool kExplicitRounding = std::is_same_v<To_type, cutlass::bfloat16_t>;
        constexpr cutlass::FloatRoundStyle kRoundStyle = cutlass::FloatRoundStyle::round_toward_zero;
#endif

        if constexpr (kExplicitRounding) {
            cutlass::NumericArrayConverter<To_type, From_type, numel, kRoundStyle> convert_op;
            *result_ptr = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
        } else {
            cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
            *result_ptr = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
        }
        return tensor_To_type;
    }
}
/// Re-distributes the values of `tOrP` between threads by staging them in
/// `sAcc`, and returns them as an A-operand fragment of `tiled_mma_o`.
///
/// Phase 1 (scatter): each thread writes `tOrP(v, 0, 0)` for v = 0..3 to
///   sAcc[(tid % 16) * 8 + tid / 16 + v * 16 * 8
///        + (warp_id % 2) * 4 + (warp_id / 2) * 16 * 32]
/// where `tid = threadIdx.x % 64` and `warp_id = threadIdx.x / 64`.
/// Phase 2: `__syncthreads()` makes the staged values visible to the block.
/// Phase 3 (gather): each thread fills fragment element (v, 0, k), for
/// v = 0..3 and k = 0..3, from
///   sAcc[tid * 8 + v + (k % 2) * 4 + (k / 2) * 16 * 32].
/// The gather address does not depend on `warp_id`.
///
/// No barrier is issued after the gather; presumably callers synchronize
/// before `sAcc` is written again — TODO confirm.
///
/// @param tiled_mma    Not used by this implementation.
/// @param tiled_mma_o  Tiled MMA whose thread slice defines the fragment shape.
/// @param tOrP         Source values; elements (0..3, 0, 0) are read.
/// @param sAcc         Staging buffer shared between the block's threads
///                     (presumably shared memory — TODO confirm).
/// @return             Fragment produced by `partition_fragment_A` on a
///                     16 x 64 row-major view of `sAcc`, filled as described.
template <class TiledMma, class TiledMma_O,
          typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ auto convert_layout_acc_Aregs(const TiledMma &tiled_mma,
                                                         const TiledMma_O &tiled_mma_o,
                                                         Tensor<Engine0, Layout0> const &tOrP,
                                                         Tensor<Engine1, Layout1> const &sAcc) {
    int tid = threadIdx.x % 64;
    int warp_id = threadIdx.x / 64;

    // Scatter: this thread's four values go to slots that are 16 * 8 apart.
    const int store_base = (tid % 16) * 8 + (tid / 16) + (warp_id % 2) * 4 + (warp_id / 2) * 16 * 32;
    #pragma unroll
    for (int v = 0; v < 4; ++v) {
        sAcc(store_base + v * 16 * 8) = tOrP(v, 0, 0);
    }

    __syncthreads();

    // A 16 x 64 row-major view of the staging buffer, used only to shape the fragment.
    using SmemLayoutAtomP = Layout<Shape<Int<16>, Int<64>>, Stride<Int<64>, _1>>;
    using SmemLayoutP = decltype(tile_to_shape(SmemLayoutAtomP{}, Shape<Int<16>, Int<64>>{}));
    Tensor sP_tmp = make_tensor(sAcc.data(), SmemLayoutP{});
    auto thr_mma = tiled_mma_o.get_thread_slice(tid);
    Tensor tSrACC = thr_mma.partition_fragment_A(sP_tmp);

    // Gather: eight consecutive slots per thread in each half of the buffer.
    const int load_base = tid * 8;
    #pragma unroll
    for (int k = 0; k < 4; ++k) {
        const int k_offset = (k % 2) * 4 + (k / 2) * 16 * 32;
        #pragma unroll
        for (int v = 0; v < 4; ++v) {
            tSrACC(v, 0, k) = sAcc(load_base + v + k_offset);
        }
    }
    return tSrACC;
}
}
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment