Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
4525c5d7
Commit
4525c5d7
authored
Dec 02, 2024
by
coderfeli
Browse files
merge upstream
parents
a8d88d8d
44828b7c
Changes
308
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3676 additions
and
217 deletions
+3676
-217
include/ck_tile/host/reference/reference_gemm.hpp
include/ck_tile/host/reference/reference_gemm.hpp
+112
-0
include/ck_tile/host/reference/reference_moe_sorting.hpp
include/ck_tile/host/reference/reference_moe_sorting.hpp
+24
-5
include/ck_tile/host/reference/reference_permute.hpp
include/ck_tile/host/reference/reference_permute.hpp
+21
-2
include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp
.../pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp
+30
-7
include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp
...ipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp
+20
-6
include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp
.../ck_tile/ops/elementwise/unary_element_wise_operation.hpp
+99
-0
include/ck_tile/ops/flatmm.hpp
include/ck_tile/ops/flatmm.hpp
+10
-0
include/ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp
...ile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp
+615
-0
include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp
.../ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp
+562
-0
include/ck_tile/ops/flatmm/block/flatmm_uk_config.hpp
include/ck_tile/ops/flatmm/block/flatmm_uk_config.hpp
+10
-0
include/ck_tile/ops/flatmm/block/uk/README.md
include/ck_tile/ops/flatmm/block/uk/README.md
+1
-0
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc
.../block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc
+613
-0
include/ck_tile/ops/flatmm/block/uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc
...tmm/block/uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc
+516
-0
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
+508
-64
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
+397
-57
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
+65
-25
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp
...ile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp
+5
-5
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
...mha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
+54
-39
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
.../ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
+13
-6
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp
+1
-1
No files found.
include/ck_tile/host/reference/reference_gemm.hpp
View file @
4525c5d7
...
...
@@ -183,4 +183,116 @@ void reference_gemm_gpu(DeviceMem& a_device,
return
;
}
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
,
typename
LayoutA
,
typename
LayoutB
,
typename
LayoutC
>
void
reference_batched_gemm_gpu
(
DeviceMem
&
a_device
,
DeviceMem
&
b_device
,
DeviceMem
&
c_device
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
stride_a
,
index_t
stride_b
,
index_t
stride_c
,
index_t
batch_stride_A
,
index_t
batch_stride_B
,
index_t
batch_stride_C
,
index_t
batch_count
)
{
ADataType
*
d_A
;
BDataType
*
d_B
;
CDataType
*
d_C
;
hipError_t
errA
=
hipMalloc
(
&
d_A
,
batch_count
*
M
*
K
*
sizeof
(
ADataType
));
hipError_t
errB
=
hipMalloc
(
&
d_B
,
batch_count
*
N
*
K
*
sizeof
(
BDataType
));
hipError_t
errC
=
hipMalloc
(
&
d_C
,
batch_count
*
M
*
N
*
sizeof
(
CDataType
));
if
(
errA
!=
hipSuccess
)
{
std
::
cerr
<<
"Error allocating device memory for A: "
<<
hipGetErrorString
(
errA
)
<<
std
::
endl
;
return
;
// Early exit on error
}
if
(
errB
!=
hipSuccess
)
{
std
::
cerr
<<
"Error allocating device memory for B: "
<<
hipGetErrorString
(
errB
)
<<
std
::
endl
;
return
;
// Early exit on error
}
if
(
errC
!=
hipSuccess
)
{
std
::
cerr
<<
"Error allocating device memory for C: "
<<
hipGetErrorString
(
errC
)
<<
std
::
endl
;
return
;
// Early exit on error
}
errA
=
hipMemcpy
(
d_A
,
a_device
.
GetDeviceBuffer
(),
batch_count
*
M
*
K
*
sizeof
(
ADataType
),
hipMemcpyHostToDevice
);
if
(
errA
!=
hipSuccess
)
{
std
::
cerr
<<
"Error copying A to device: "
<<
hipGetErrorString
(
errA
)
<<
std
::
endl
;
}
errB
=
hipMemcpy
(
d_B
,
b_device
.
GetDeviceBuffer
(),
batch_count
*
N
*
K
*
sizeof
(
BDataType
),
hipMemcpyHostToDevice
);
if
(
errB
!=
hipSuccess
)
{
std
::
cerr
<<
"Error copying B to device: "
<<
hipGetErrorString
(
errB
)
<<
std
::
endl
;
}
int
totalElements
=
M
*
N
;
int
numThreadsPerBlock
=
256
;
// Common choice for threads per block
int
numBlocks
=
(
totalElements
+
numThreadsPerBlock
-
1
)
/
numThreadsPerBlock
;
for
(
index_t
batch_id
=
0
;
batch_id
<
batch_count
;
++
batch_id
)
{
ADataType
*
d_ATemp
=
d_A
+
batch_id
*
batch_stride_A
;
BDataType
*
d_BTemp
=
d_B
+
batch_id
*
batch_stride_B
;
CDataType
*
d_CTemp
=
d_C
+
batch_id
*
batch_stride_C
;
naive_gemm_kernel
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
LayoutA
,
LayoutB
,
LayoutC
>
<<<
numBlocks
,
numThreadsPerBlock
>>>
(
d_ATemp
,
d_BTemp
,
d_CTemp
,
M
,
N
,
K
,
stride_a
,
stride_b
,
stride_c
);
}
errC
=
hipMemcpy
(
c_device
.
GetDeviceBuffer
(),
d_C
,
batch_count
*
M
*
N
*
sizeof
(
CDataType
),
hipMemcpyDeviceToHost
);
if
(
errC
!=
hipSuccess
)
{
std
::
cerr
<<
"Error copying C to device: "
<<
hipGetErrorString
(
errC
)
<<
std
::
endl
;
}
errA
=
hipFree
(
d_A
);
if
(
errA
!=
hipSuccess
)
{
std
::
cerr
<<
"Error free the A memory: "
<<
hipGetErrorString
(
errA
)
<<
std
::
endl
;
}
errB
=
hipFree
(
d_B
);
if
(
errB
!=
hipSuccess
)
{
std
::
cerr
<<
"Error free the B memory: "
<<
hipGetErrorString
(
errB
)
<<
std
::
endl
;
}
errC
=
hipFree
(
d_C
);
if
(
errC
!=
hipSuccess
)
{
std
::
cerr
<<
"Error free the C memory: "
<<
hipGetErrorString
(
errC
)
<<
std
::
endl
;
}
return
;
}
}
// namespace ck_tile
include/ck_tile/host/reference/reference_moe_sorting.hpp
View file @
4525c5d7
...
...
@@ -8,6 +8,9 @@
namespace
ck_tile
{
#define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \
static_cast<uint32_t>(((token_id_)&0x00ffffff) | (((topk_id_)&0xff) << 24))
template
<
typename
WeightType
,
typename
IndexType
=
index_t
>
CK_TILE_HOST
void
reference_moe_sorting
(
const
HostTensor
<
IndexType
>&
topk_ids
,
const
HostTensor
<
WeightType
>&
weights
,
...
...
@@ -20,8 +23,14 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
{
const
index_t
num_token
=
topk_ids
.
mDesc
.
get_lengths
()[
0
];
const
index_t
topk
=
topk_ids
.
mDesc
.
get_lengths
()[
1
];
std
::
vector
<
std
::
vector
<
IndexType
>>
expert_tokens
(
experts
,
// allocate a temp buffer, and fill the value with [number_token|topk]
std
::
vector
<
std
::
vector
<
IndexType
>>
expert_tokens
(
experts
,
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
std
::
vector
<
IndexType
>
(
unit_size
,
MOE_SORTING_MOCK_ID
(
num_token
,
topk
)));
#else
std
::
vector
<
IndexType
>
(
unit_size
,
num_token
));
#endif
std
::
vector
<
std
::
vector
<
WeightType
>>
expert_token_weights
(
experts
,
std
::
vector
<
WeightType
>
(
unit_size
,
0
));
std
::
vector
<
IndexType
>
expert_slices
(
experts
,
1
);
...
...
@@ -42,12 +51,19 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
expert_token_weights
[
e
].
resize
(
new_size
);
for
(
index_t
i
=
(
expert_slices
[
e
]
-
1
)
*
unit_size
;
i
<
new_size
;
i
++
)
{
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
expert_tokens
[
e
][
i
]
=
MOE_SORTING_MOCK_ID
(
num_token
,
topk
);
#else
expert_tokens
[
e
][
i
]
=
num_token
;
#endif
expert_token_weights
[
e
][
i
]
=
0
;
}
}
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
expert_tokens
[
e
][
idx
]
=
MOE_SORTING_MOCK_ID
(
t
,
k
);
#else
expert_tokens
[
e
][
idx
]
=
t
;
#endif
expert_token_weights
[
e
][
idx
]
=
w
;
expert_slice_idxs
[
e
]
++
;
}
...
...
@@ -75,4 +91,7 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
unit_cnt
*=
unit_size
;
return
;
}
#undef MOE_SORTING_MOCK_ID
}
// namespace ck_tile
include/ck_tile/host/reference/reference_permute.hpp
View file @
4525c5d7
...
...
@@ -16,7 +16,7 @@ namespace ck_tile {
*/
template
<
typename
DataType
>
CK_TILE_HOST
void
reference_permute
(
const
HostTensor
<
DataType
>&
x
,
HostTensor
<
DataType
>&
y
,
std
::
vector
<
index_t
>
dims
)
reference_permute
(
const
HostTensor
<
DataType
>&
x
,
HostTensor
<
DataType
>&
y
,
std
::
vector
<
index_t
>
perm
)
{
const
auto
x_len
=
x
.
mDesc
.
get_lengths
();
const
auto
y_len
=
y
.
mDesc
.
get_lengths
();
...
...
@@ -43,7 +43,7 @@ reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::v
std
::
vector
<
size_t
>
tmp
(
rank
,
0
);
for
(
index_t
i
=
0
;
i
<
rank
;
i
++
)
{
tmp
[
dims
[
i
]]
=
y_coord
[
i
];
tmp
[
perm
[
i
]]
=
y_coord
[
i
];
}
return
tmp
;
}();
...
...
@@ -54,4 +54,23 @@ reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::v
make_ParallelTensorFunctor
(
f
,
x_elm
)(
std
::
thread
::
hardware_concurrency
());
}
template
<
typename
DataType
>
CK_TILE_HOST
auto
reference_permute
(
const
HostTensor
<
DataType
>&
x
,
std
::
vector
<
index_t
>
perm
)
{
auto
x_shape
=
x
.
get_lengths
();
ck_tile
::
index_t
rank
=
perm
.
size
();
std
::
vector
<
ck_tile
::
index_t
>
y_shape
=
[
&
]()
{
std
::
vector
<
ck_tile
::
index_t
>
tmp
(
rank
,
0
);
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
rank
);
i
++
)
{
tmp
[
i
]
=
x_shape
[
perm
[
i
]];
}
return
tmp
;
}();
HostTensor
<
DataType
>
y
(
y_shape
);
reference_permute
(
x
,
y
,
perm
);
return
y
;
}
}
// namespace ck_tile
include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp
View file @
4525c5d7
...
...
@@ -30,6 +30,7 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass
static
constexpr
bool
kNeedCrossWarpSync
=
Problem
::
kNeedCrossWarpSync
;
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockAddRmsnorm2dRdquantFwdProblem::kPadM
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
UseMax3
=
true
;
// TODO - Move to trait
static
constexpr
const
char
*
name
=
[]()
{
if
constexpr
(
kNeedCrossWarpSync
)
...
...
@@ -69,6 +70,13 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass
auto
reduce_square_sum_func
=
ReduceOp
::
SquareAdd
{};
auto
reduce_sum_func
=
ReduceOp
::
Add
{};
auto
reduce_absmax_func
=
ReduceOp
::
AbsMax
{};
auto
reduce_absmax3_func
=
[](
auto
acc_
,
auto
v_0_
,
auto
v_1_
)
{
float
rtn
;
asm
volatile
(
"v_max3_f32 %0, %1, abs(%2), abs(%3)"
:
"=v"
(
rtn
)
:
"v"
(
acc_
),
"v"
(
v_0_
),
"v"
(
v_1_
));
return
rtn
;
};
auto
reduce_max_func
=
ReduceOp
::
Max
{};
auto
block_reduce2d
=
Policy
::
template
GetBlockReduce2d
<
Problem
>();
auto
block_reduce2d_sync
=
Policy
::
template
GetBlockReduce2dSync
<
Problem
>();
...
...
@@ -116,8 +124,23 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass
});
// compute absmax, each-thread->cross-lane->cross-warp
auto
absmax
=
block_reduce2d
(
auto
absmax
=
[
&
]()
{
constexpr
auto
x_size_per_row
=
x
.
get_tile_distribution
().
get_ys_to_d_descriptor
().
get_lengths
().
at
(
number
<
1
>
{});
if
constexpr
(
UseMax3
&&
std
::
is_same_v
<
ComputeDataType
,
float
>
&&
x_size_per_row
%
2
==
0
)
{
return
block_reduce2d
(
y
,
reduce_absmax_func
.
GetIdentityValue
<
ComputeDataType
>
(),
reduce_absmax3_func
,
sequence
<
1
,
2
>
{});
}
else
{
return
block_reduce2d
(
y
,
reduce_absmax_func
.
GetIdentityValue
<
ComputeDataType
>
(),
reduce_absmax_func
);
}
}();
block_reduce2d_sync
(
absmax
,
reduce_max_func
);
block_reduce2d_cross_warp_sync
(
absmax
,
smem
,
reduce_max_func
);
...
...
include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp
View file @
4525c5d7
...
...
@@ -30,6 +30,7 @@ struct AddRmsnorm2dRdquantFwdPipelineThreePass
static
constexpr
bool
kNeedCrossWarpSync
=
Problem
::
kNeedCrossWarpSync
;
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockAddRmsnorm2dRdquantFwdProblem::kPadM
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
UseMax3
=
true
;
// TODO - Move to trait
static
constexpr
const
char
*
name
=
[]()
{
if
constexpr
(
kNeedCrossWarpSync
)
...
...
@@ -76,6 +77,13 @@ struct AddRmsnorm2dRdquantFwdPipelineThreePass
auto
reduce_square_sum_func
=
ReduceOp
::
SquareAdd
{};
auto
reduce_sum_func
=
ReduceOp
::
Add
{};
auto
reduce_absmax_func
=
ReduceOp
::
AbsMax
{};
auto
reduce_absmax3_func
=
[](
auto
acc_
,
auto
v_0_
,
auto
v_1_
)
{
float
rtn
;
asm
volatile
(
"v_max3_f32 %0, %1, abs(%2), abs(%3)"
:
"=v"
(
rtn
)
:
"v"
(
acc_
),
"v"
(
v_0_
),
"v"
(
v_1_
));
return
rtn
;
};
auto
reduce_max_func
=
ReduceOp
::
Max
{};
auto
block_reduce2d
=
Policy
::
template
GetBlockReduce2d
<
Problem
>();
auto
block_reduce2d_sync
=
Policy
::
template
GetBlockReduce2dSync
<
Problem
>();
...
...
@@ -177,6 +185,12 @@ struct AddRmsnorm2dRdquantFwdPipelineThreePass
y
(
idx
)
=
type_convert
<
ComputeDataType
>
(
y_
);
});
constexpr
auto
x_size_per_row
=
x
.
get_tile_distribution
().
get_ys_to_d_descriptor
().
get_lengths
().
at
(
number
<
1
>
{});
if
constexpr
(
UseMax3
&&
std
::
is_same_v
<
ComputeDataType
,
float
>
&&
x_size_per_row
%
2
==
0
)
block_reduce2d
(
y
,
absmax
,
reduce_absmax3_func
,
sequence
<
1
,
2
>
{});
else
block_reduce2d
(
y
,
absmax
,
reduce_absmax_func
);
if
constexpr
(
kSaveX
)
...
...
include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp
View file @
4525c5d7
...
...
@@ -572,6 +572,105 @@ struct FastGelu
}
};
struct
FastGeluAsm
{
template
<
typename
Y
,
typename
X
>
CK_TILE_HOST
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
;
template
<
typename
Y
,
typename
X
>
CK_TILE_DEVICE
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
;
template
<
>
CK_TILE_HOST
void
operator
()
<
float
,
float
>
(
float
&
y
,
const
float
&
x
)
const
{
// const float u = -2.f * x * (0.035677f * x * x + 0.797885f);
const
float
c1
=
-
2.0
*
0.035677
f
;
const
float
c2
=
-
2.0
*
0.797885
f
;
const
float
u
=
x
*
(
c1
*
x
*
x
+
c2
);
const
float
emu
=
exp
(
u
);
y
=
x
/
(
1.
f
+
emu
);
}
// device code, use lower precision "__ocml_exp_f32" and "rcp"
template
<
>
CK_TILE_DEVICE
void
operator
()
<
float
,
float
>
(
float
&
y
,
const
float
&
x
)
const
{
const
uint32_t
c1
=
0xbd92220c
;
// -2.0 * 0.035677f;
const
float
c2
=
-
2.0
*
0.797885
f
;
const
uint32_t
log2e_
=
0x3fb8aa3b
;
// log2e_v<float>;
float
tmp
;
asm
volatile
(
"v_mul_f32 %[v_tmp], %[v_x], %[v_x] ; x*x
\n
"
"v_fma_f32 %[v_tmp], %[v_tmp], %[s_c1], %[v_c2] ; c1*x*x+c2
\n
"
"v_mul_f32 %[v_tmp], %[v_tmp], %[v_x] ; x*(c1*x*x+c2)
\n
"
"v_mul_f32 %[v_tmp], %[v_tmp], %[s_log2e] ; log2e*x*(c1*x*x+c2)
\n
"
"v_exp_f32 %[v_tmp], %[v_tmp] ; emu = exp2(log2e*x*(c1*x*x+c2))
\n
"
"s_nop 0 ; hazard for exp
\n
"
"v_add_f32 %[v_tmp], %[v_tmp], 1.0 ; emu+1.0f
\n
"
"v_rcp_f32 %[v_tmp], %[v_tmp] ; 1/(emu+1.0f)
\n
"
"s_nop 0 ; hazard for rcp
\n
"
"v_mul_f32 %[v_y], %[v_tmp], %[v_x] ; x * 1/(emu+1f)
\n
"
:
[
v_y
]
"=v"
(
y
),
[
v_tmp
]
"+v"
(
tmp
)
:
[
v_x
]
"v"
(
x
),
[
s_c1
]
"s"
(
c1
),
[
v_c2
]
"v"
(
c2
),
[
s_log2e
]
"s"
(
log2e_
)
:
);
}
template
<
>
CK_TILE_HOST
void
operator
()
<
fp32x2_t
,
fp32x2_t
>
(
fp32x2_t
&
y
,
const
fp32x2_t
&
x
)
const
{
const
float
c1
=
-
2.0
*
0.035677
f
;
const
float
c2
=
-
2.0
*
0.797885
f
;
const
float
u0
=
x
.
x
*
(
c1
*
x
.
x
*
x
.
x
+
c2
);
const
float
emu0
=
exp
(
u0
);
y
.
x
=
x
.
x
/
(
1.
f
+
emu0
);
const
float
u1
=
x
.
y
*
(
c1
*
x
.
y
*
x
.
y
+
c2
);
const
float
emu1
=
exp
(
u1
);
y
.
y
=
x
.
y
/
(
1.
f
+
emu1
);
}
// this is packed verion to remove data hazard for trans
template
<
>
CK_TILE_DEVICE
void
operator
()
<
fp32x2_t
,
fp32x2_t
>
(
fp32x2_t
&
y
,
const
fp32x2_t
&
x
)
const
{
const
uint32_t
c1
=
0xbd92220c
;
// -2.0 * 0.035677f;
float
c2
=
-
2.0
*
0.797885
f
;
const
uint32_t
log2e_
=
0x3fb8aa3b
;
// log2e_v<float>;
float
tmp0
,
tmp1
;
float
y0
=
x
.
x
,
y1
=
x
.
y
;
asm
volatile
(
"v_mul_f32 %[v_tmp0], %[v_y0], %[v_y0] ; x*x
\n
"
"v_mul_f32 %[v_tmp1], %[v_y1], %[v_y1] ; x*x
\n
"
"v_fma_f32 %[v_tmp0], %[v_tmp0], %[s_c1], %[v_c2] ; c1*x*x+c2
\n
"
"v_fma_f32 %[v_tmp1], %[v_tmp1], %[s_c1], %[v_c2] ; c1*x*x+c2
\n
"
"v_mul_f32 %[v_tmp0], %[v_tmp0], %[v_y0] ; x*(c1*x*x+c2)
\n
"
"v_mul_f32 %[v_tmp1], %[v_tmp1], %[v_y1] ; x*(c1*x*x+c2)
\n
"
"v_mul_f32 %[v_tmp0], %[v_tmp0], %[s_log2e] ; log2e*x*(c1*x*x+c2)
\n
"
"v_mul_f32 %[v_tmp1], %[v_tmp1], %[s_log2e] ; log2e*x*(c1*x*x+c2)
\n
"
"v_exp_f32 %[v_tmp0], %[v_tmp0] ; emu = exp2(log2e*x*(c1*x*x+c2))
\n
"
"v_exp_f32 %[v_tmp1], %[v_tmp1] ; emu = exp2(log2e*x*(c1*x*x+c2))
\n
"
"v_add_f32 %[v_tmp0], %[v_tmp0], 1.0 ; emu+1.0f
\n
"
"v_add_f32 %[v_tmp1], %[v_tmp1], 1.0 ; emu+1.0f
\n
"
"v_rcp_f32 %[v_tmp0], %[v_tmp0] ; 1/(emu+1.0f)
\n
"
"v_rcp_f32 %[v_tmp1], %[v_tmp1] ; 1/(emu+1.0f)
\n
"
"v_mul_f32 %[v_y0], %[v_tmp0], %[v_y0] ; x * 1/(emu+1f)
\n
"
"v_mul_f32 %[v_y1], %[v_tmp1], %[v_y1] ; x * 1/(emu+1f)
\n
"
:
[
v_y0
]
"+v"
(
y0
),
[
v_y1
]
"+v"
(
y1
),
[
v_c2
]
"+v"
(
c2
),
// NOTE! it is totally possible that c2/y0/y1 share same register, they are all local
// tmp variables we need to expicitly hint compiler they may read+write, to allow
// allocate different register , the side effect is c2=** may issue for every such
// inline asm block
[
v_tmp0
]
"+v"
(
tmp0
),
[
v_tmp1
]
"+v"
(
tmp1
)
:
[
s_c1
]
"s"
(
c1
),
[
s_log2e
]
"s"
(
log2e_
)
:
);
y
.
x
=
y0
;
y
.
y
=
y1
;
}
};
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+erf(x/sqrt(2)))
struct
Gelu
...
...
include/ck_tile/ops/
moe_sorting
.hpp
→
include/ck_tile/ops/
flatmm
.hpp
View file @
4525c5d7
...
...
@@ -3,9 +3,8 @@
#pragma once
#include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp
0 → 100644
View file @
4525c5d7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
namespace
ck_tile
{
// A async load to LDS, B direct to AGPR
// B matrix preshuffled in br*kr*w
// require 4 wave, occupancy=1c
// agpr useage:256
// vgpr usage:64(A local) + 64(acc) + 8(os_a) + 8(os_b) = 144 (rem:112)
//
// for this gemm, 4 16x16x16 transposed layout
// input A vpgpr layout
// v0-v15: [ 0:15](gemm_m)x128(gemm_k)
// v16-v31: [16:31](gemm_m)x128(gemm_k)
// input B vpgpr layout
// v0-v15: [ 0: 15](gemm_n)x128(gemm_k)
// v16-v31: [ 64: 79](gemm_n)x128(gemm_k)
// ......................
// v111-v127: [448:463](gemm_n)x128(gemm_k)
// output C vpgpr layout
// v0-v3 : [ 0:15](gemm_m)x[ 0: 15](gemm_n)
// v4-v7 : [16:31](gemm_m)x[ 0: 15](gemm_n)
// v8-v11: [ 0:15](gemm_m)x[64: 79](gemm_n)
// v12-v15: [16:31](gemm_m)x[64: 79](gemm_n)
// ......................
// v56-v59: [ 0:15](gemm_m)x[448:463](gemm_n)
// v60-v63: [16:31](gemm_m)x[448:463](gemm_n)
struct
Flatmm_32x512x128_1x4x1_16x16x32_Base
// for f16/bf16
{
static
constexpr
index_t
Block_M
=
32
;
static
constexpr
index_t
Block_N
=
512
;
static
constexpr
index_t
Block_K
=
128
;
static
constexpr
index_t
WarpPerBlock_M
=
1
;
static
constexpr
index_t
WarpPerBlock_N
=
4
;
static
constexpr
index_t
WarpPerBlock_K
=
1
;
static
constexpr
index_t
NumWarps
=
4
;
static
constexpr
index_t
Warp_M
=
16
;
static
constexpr
index_t
Warp_N
=
16
;
static
constexpr
index_t
Warp_K
=
32
;
// 16 * SubKPacks
static
constexpr
index_t
BlockSize
=
256
;
static
constexpr
index_t
SubKPacks
=
2
;
// this is used to gurantee every threads can do dwordx4
// TODO: note Nr/Kr/W need consider SubKPacks
static
constexpr
index_t
Block_W
=
Warp_N
*
Warp_K
;
// 512 element
static
constexpr
index_t
Block_Nr
=
Block_N
/
Warp_N
;
// 32 element, 4 per wave
static
constexpr
index_t
Block_Kr
=
Block_K
/
Warp_K
;
// 4
static
constexpr
index_t
Repeat_M
=
Block_M
/
(
Warp_M
*
WarpPerBlock_M
);
// 2
static
constexpr
index_t
Repeat_N
=
Block_N
/
(
Warp_N
*
WarpPerBlock_N
);
// 8
static
constexpr
index_t
Repeat_K
=
Block_K
/
(
Warp_K
*
WarpPerBlock_K
);
// 8/2=4
static
CK_TILE_DEVICE
constexpr
auto
MakeCBlockDist
()
{
constexpr
auto
c_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Repeat_M
,
WarpPerBlock_M
>
,
sequence
<
Repeat_N
,
WarpPerBlock_N
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
1
>>
,
sequence
<
2
,
1
>
,
// !! note here is different
sequence
<
0
,
0
>>
{};
using
WG
=
WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution
;
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
c_block_outer_dstr_encoding
,
typename
WG
::
CWarpDstrEncoding
{});
constexpr
auto
c_block_dstr
=
make_static_tile_distribution
(
c_block_dstr_encode
);
return
c_block_dstr
;
}
static
CK_TILE_DEVICE
constexpr
auto
MakeCBlockTile
()
{
using
CDataType
=
float
;
constexpr
auto
c_block_dstr
=
MakeCBlockDist
();
auto
c_block_tensor
=
make_static_distributed_tensor
<
CDataType
>
(
c_block_dstr
);
return
c_block_tensor
;
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLdsStoreDesc_A
()
{
// A async->LDS
// constexpr index_t Block_M = Problem::BlockShape::Block_M0;
// constexpr index_t Block_K = Problem::BlockShape::Block_K0;
// constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
constexpr
index_t
warpSize
=
ck_tile
::
get_warp_size
();
// constexpr index_t NumWarps = Problem::BlockShape::NumWarps;
constexpr
index_t
KPack_
=
8
;
// GetSmemKPack_A<Problem>(); // LDS
constexpr
index_t
KVector
=
2
;
// GetAlignment_A<Problem>(); // async copy 1 dword
constexpr
index_t
KPad
=
KPack_
;
// pad between warps
static_assert
(
Block_K
%
KVector
==
0
);
constexpr
index_t
LanesPerK
=
Block_K
/
KVector
;
// how many thread loading K
if
constexpr
(
LanesPerK
>=
warpSize
)
{
// need multiple waves to load K
static_assert
(
LanesPerK
%
warpSize
==
0
);
constexpr
index_t
wavesPerK
=
LanesPerK
/
warpSize
;
if
constexpr
(
wavesPerK
>
NumWarps
)
{
// TODO: need multiple issues along K to load all data
}
else
{
constexpr
index_t
wavesPerM
=
NumWarps
/
wavesPerK
;
constexpr
index_t
NumIssues
=
Block_M
/
wavesPerM
;
constexpr
auto
lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
NumIssues
>
{},
// m0
number
<
wavesPerM
>
{},
// m1
number
<
wavesPerK
>
{},
// k0
number
<
warpSize
>
{},
// k1
number
<
KVector
>
{}),
// k2
make_tuple
(
number
<
NumWarps
*
(
warpSize
*
KVector
+
KPad
)
>
{},
// m0
number
<
wavesPerK
*
(
warpSize
*
KVector
+
KPad
)
>
{},
// m1
number
<
warpSize
*
KVector
+
KPad
>
{},
// k0
number
<
KVector
>
{},
// k1
number
<
1
>
{}),
// k2
number
<
KVector
>
{},
// lds store vector(actually no explicit store)
number
<
1
>
{});
constexpr
auto
lds_block_desc_issues_warps_lanes
=
transform_tensor_descriptor
(
lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
number
<
NumIssues
>
{}),
make_merge_transform
(
make_tuple
(
number
<
wavesPerM
>
{},
number
<
wavesPerK
>
{})),
make_merge_transform
(
make_tuple
(
number
<
warpSize
>
{},
number
<
KVector
>
{}))),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
,
2
>
{},
sequence
<
3
,
4
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{}));
return
lds_block_desc_issues_warps_lanes
;
}
}
else
{
// lanes within a wave load different M but same K
static_assert
(
warpSize
%
LanesPerK
==
0
);
constexpr
index_t
LaneGroups
=
warpSize
/
LanesPerK
;
// along m
constexpr
index_t
NumIssues
=
Block_M
/
(
LaneGroups
*
NumWarps
);
constexpr
auto
lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
NumIssues
>
{},
// m0
number
<
LaneGroups
>
{},
// m1
number
<
NumWarps
>
{},
// m2
number
<
LanesPerK
>
{},
// k0
number
<
KVector
>
{}),
// k1
make_tuple
(
number
<
NumWarps
*
(
warpSize
*
KVector
+
KPad
)
>
{},
// m0
number
<
Block_K
>
{},
// m1
number
<
warpSize
*
KVector
+
KPad
>
{},
// m2
number
<
KVector
>
{},
// k0
number
<
1
>
{}),
// k1
number
<
KVector
>
{},
// lds store vector(actually no explicit store)
number
<
1
>
{});
constexpr
auto
lds_block_desc_issues_warps_lanes
=
transform_tensor_descriptor
(
lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
number
<
NumIssues
>
{}),
make_pass_through_transform
(
number
<
NumWarps
>
{}),
make_merge_transform
(
make_tuple
(
number
<
LaneGroups
>
{},
number
<
LanesPerK
>
{},
number
<
KVector
>
{}))),
make_tuple
(
sequence
<
0
>
{},
sequence
<
2
>
{},
sequence
<
1
,
3
,
4
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{}));
return
lds_block_desc_issues_warps_lanes
;
}
}
// template <typename Problem>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLdsLoadDesc_A
()
{
// load from LDS to register, every wave has same layout
constexpr
index_t
KPack_
=
8
;
// GetSmemKPack_A<Problem>(); // LDS
constexpr
index_t
KPad
=
KPack_
;
// pad between warps
constexpr
index_t
kAMLane
=
16
;
constexpr
index_t
kABKLane
=
4
;
constexpr
index_t
kABKPerLane
=
4
;
constexpr
index_t
kKIter
=
2
;
static_assert
(
KPack_
==
(
kABKPerLane
*
kKIter
));
constexpr
auto
lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
Repeat_M
>
{},
// m0 y
number
<
kAMLane
>
{},
// m1 p
number
<
Repeat_K
>
{},
// k0 y
number
<
kABKLane
>
{},
// k1 p
number
<
KPack_
>
{}),
// k2 y-vector
make_tuple
(
number
<
kAMLane
*
(
Block_K
+
KPad
)
>
{},
// m0
number
<
Block_K
+
KPad
>
{},
// m1
number
<
kABKLane
*
KPack_
>
{},
// k0
number
<
KPack_
>
{},
// k1
number
<
1
>
{}),
// k2
number
<
KPack_
>
{},
// lds load vector
number
<
1
>
{});
constexpr
auto
lds_desc_m_k
=
transform_tensor_descriptor
(
lds_block_desc_0
,
make_tuple
(
make_merge_transform
(
make_tuple
(
number
<
Repeat_M
>
{},
number
<
kAMLane
>
{})),
make_merge_transform
(
make_tuple
(
number
<
Repeat_K
>
{},
number
<
kABKLane
>
{},
number
<
KPack_
>
{}))),
make_tuple
(
sequence
<
0
,
1
>
{},
sequence
<
2
,
3
,
4
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
lds_desc_m_k
;
}
static
constexpr
auto
GetGemm_AWarpEnc
()
{
constexpr
index_t
kAMLane
=
16
;
constexpr
index_t
kABKLane
=
4
;
constexpr
index_t
kABKPerLane
=
4
;
constexpr
index_t
kKIter
=
2
;
using
enc_
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
kAMLane
>
,
sequence
<
kABKLane
,
kABKPerLane
*
kKIter
>>
,
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
>>
,
sequence
<
2
>
,
sequence
<
1
>>
;
return
enc_
{};
}
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
return
32
*
(
128
+
8
)
*
sizeof
(
bf16_t
);
}
};
struct
Flatmm_32x512x128_1x4x1_16x16x32_BF16
:
public
Flatmm_32x512x128_1x4x1_16x16x32_Base
{
using
ADataType
=
bf16_t
;
using
BDataType
=
bf16_t
;
// TODO: need paired with tile_window_linear!
// TODO: need call init_raw() before call this function!
template
<
typename
ARes
,
typename
ACoords
,
typename
BRes
,
typename
BCoords
>
CK_TILE_DEVICE
auto
operator
()(
const
ARes
&
res_a
,
const
ACoords
&
cached_coords_a
,
const
BRes
&
res_b
,
const
BCoords
&
cached_coords_b
,
CK_TILE_LDS_ADDR
void
*
smem
,
index_t
k
,
index_t
tile_offset_a
,
// for each tile, the offset to move for each unroll
index_t
tile_offset_b
)
// for each tile, the offset to move for each unroll
{
static_assert
(
ACoords
::
size
()
==
Block_M
*
Block_K
/
BlockSize
/
2
/*2x per dword*/
);
// 8
static_assert
(
BCoords
::
size
()
==
Repeat_N
);
auto
a_sst
=
make_tile_window
(
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
CK_TILE_LDS_ADDR
ADataType
*>
(
smem
),
MakeLdsStoreDesc_A
()),
MakeLdsStoreDesc_A
().
get_lengths
(),
{
0
,
0
,
0
});
auto
a_sld
=
[
&
]()
{
constexpr
auto
a_warp_enc_
=
GetGemm_AWarpEnc
();
constexpr
auto
a_outer_dstr_enc
=
tile_distribution_encoding
<
sequence
<
WarpPerBlock_N
>
,
tuple
<
sequence
<
Repeat_M
,
WarpPerBlock_M
>
,
sequence
<
Repeat_K
>>
,
tuple
<
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
a_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
a_outer_dstr_enc
,
a_warp_enc_
);
return
make_tile_window_linear
(
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
CK_TILE_LDS_ADDR
ADataType
*>
(
smem
),
MakeLdsLoadDesc_A
()),
MakeLdsLoadDesc_A
().
get_lengths
(),
{
0
,
0
},
make_static_tile_distribution
(
a_block_dstr_encode
));
}();
const
index_t
tile_offset_a_bytes
=
tile_offset_a
*
sizeof
(
ADataType
);
const
index_t
tile_offset_b_bytes
=
tile_offset_b
*
sizeof
(
BDataType
);
const
auto
[
m0_init_value
,
size_per_issue
]
=
get_async_store_smem_info
(
a_sst
);
constexpr
auto
smem_buf_size
=
MakeLdsLoadDesc_A
().
get_element_space_size
()
*
sizeof
(
ADataType
);
static_assert
(
a_sld
.
get_num_of_access
()
==
8
);
constexpr
auto
sld_os
=
generate_tuple
(
[
&
](
auto
i_access
)
{
return
number
<
a_sld
.
get_bottom_linear_offset
(
i_access
)
*
sizeof
(
ADataType
)
>
{};
},
number
<
a_sld
.
get_num_of_access
()
>
{});
index_t
loop_cnt
=
k
/
Block_K
;
// this is the acc thread buffer
fp32x4_t
v_acc
[
16
]{
.0
f
};
// B nr->kr
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Winline-asm"
// clang-format off
asm
volatile
(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
#include "uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc"
#undef CK_TILE_FLATMM_UK_MFMA
:
[
s_loop_cnt
]
"+s"
(
loop_cnt
),
[
v_acc_0
]
"+v"
(
v_acc
[
0
]),
[
v_acc_1
]
"+v"
(
v_acc
[
1
]),
[
v_acc_2
]
"+v"
(
v_acc
[
2
]),
[
v_acc_3
]
"+v"
(
v_acc
[
3
]),
[
v_acc_4
]
"+v"
(
v_acc
[
4
]),
[
v_acc_5
]
"+v"
(
v_acc
[
5
]),
[
v_acc_6
]
"+v"
(
v_acc
[
6
]),
[
v_acc_7
]
"+v"
(
v_acc
[
7
]),
[
v_acc_8
]
"+v"
(
v_acc
[
8
]),
[
v_acc_9
]
"+v"
(
v_acc
[
9
]),
[
v_acc_10
]
"+v"
(
v_acc
[
10
]),
[
v_acc_11
]
"+v"
(
v_acc
[
11
]),
[
v_acc_12
]
"+v"
(
v_acc
[
12
]),
[
v_acc_13
]
"+v"
(
v_acc
[
13
]),
[
v_acc_14
]
"+v"
(
v_acc
[
14
]),
[
v_acc_15
]
"+v"
(
v_acc
[
15
]),
[
s_mem_
]
"+r"
(
smem
)
:
[
s_res_a0
]
"s"
(
res_a
[
0
]),
[
s_res_a1
]
"s"
(
res_a
[
1
]),
[
s_res_a2
]
"s"
(
res_a
[
2
]),
[
s_res_a3
]
"s"
(
res_a
[
3
]),
[
s_res_b0
]
"s"
(
res_b
[
0
]),
[
s_res_b1
]
"s"
(
res_b
[
1
]),
[
s_res_b2
]
"s"
(
res_b
[
2
]),
[
s_res_b3
]
"s"
(
res_b
[
3
]),
[
v_os_a0
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_a
[
number
<
0
>
{}]
*
sizeof
(
ADataType
))),
[
v_os_a1
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_a
[
number
<
1
>
{}]
*
sizeof
(
ADataType
))),
[
v_os_a2
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_a
[
number
<
2
>
{}]
*
sizeof
(
ADataType
))),
[
v_os_a3
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_a
[
number
<
3
>
{}]
*
sizeof
(
ADataType
))),
[
v_os_a4
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_a
[
number
<
4
>
{}]
*
sizeof
(
ADataType
))),
[
v_os_a5
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_a
[
number
<
5
>
{}]
*
sizeof
(
ADataType
))),
[
v_os_a6
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_a
[
number
<
6
>
{}]
*
sizeof
(
ADataType
))),
[
v_os_a7
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_a
[
number
<
7
>
{}]
*
sizeof
(
ADataType
))),
[
v_os_b0
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
0
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b1
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
1
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b2
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
2
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b3
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
3
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b4
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
4
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b5
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
5
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b6
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
6
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b7
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
7
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_slda
]
"v"
(
static_cast
<
index_t
>
(
a_sld
.
cached_coords_
[
number
<
0
>
{}].
get_offset
()
*
sizeof
(
ADataType
))),
[
s_m0_init
]
"s"
(
m0_init_value
),
[
s_size_per_issue
]
"s"
(
size_per_issue
),
[
smem_sz
]
"n"
(
smem_buf_size
),
//(smem_buf_size),
[
sld_os_0
]
"n"
(
sld_os
[
number
<
0
>
{}].
value
),
[
sld_os_1
]
"n"
(
sld_os
[
number
<
1
>
{}].
value
),
[
sld_os_2
]
"n"
(
sld_os
[
number
<
2
>
{}].
value
),
[
sld_os_3
]
"n"
(
sld_os
[
number
<
3
>
{}].
value
),
[
sld_os_4
]
"n"
(
sld_os
[
number
<
4
>
{}].
value
),
[
sld_os_5
]
"n"
(
sld_os
[
number
<
5
>
{}].
value
),
[
sld_os_6
]
"n"
(
sld_os
[
number
<
6
>
{}].
value
),
[
sld_os_7
]
"n"
(
sld_os
[
number
<
7
>
{}].
value
),
[
s_tile_os_a
]
"s"
(
tile_offset_a_bytes
),
[
s_tile_os_b
]
"s"
(
tile_offset_b_bytes
)
:
"memory"
,
"a0"
,
"a1"
,
"a2"
,
"a3"
,
"a4"
,
"a5"
,
"a6"
,
"a7"
,
"a8"
,
"a9"
,
"a10"
,
"a11"
,
"a12"
,
"a13"
,
"a14"
,
"a15"
,
"a16"
,
"a17"
,
"a18"
,
"a19"
,
"a20"
,
"a21"
,
"a22"
,
"a23"
,
"a24"
,
"a25"
,
"a26"
,
"a27"
,
"a28"
,
"a29"
,
"a30"
,
"a31"
,
"a32"
,
"a33"
,
"a34"
,
"a35"
,
"a36"
,
"a37"
,
"a38"
,
"a39"
,
"a40"
,
"a41"
,
"a42"
,
"a43"
,
"a44"
,
"a45"
,
"a46"
,
"a47"
,
"a48"
,
"a49"
,
"a50"
,
"a51"
,
"a52"
,
"a53"
,
"a54"
,
"a55"
,
"a56"
,
"a57"
,
"a58"
,
"a59"
,
"a60"
,
"a61"
,
"a62"
,
"a63"
,
"a64"
,
"a65"
,
"a66"
,
"a67"
,
"a68"
,
"a69"
,
"a70"
,
"a71"
,
"a72"
,
"a73"
,
"a74"
,
"a75"
,
"a76"
,
"a77"
,
"a78"
,
"a79"
,
"a80"
,
"a81"
,
"a82"
,
"a83"
,
"a84"
,
"a85"
,
"a86"
,
"a87"
,
"a88"
,
"a89"
,
"a90"
,
"a91"
,
"a92"
,
"a93"
,
"a94"
,
"a95"
,
"a96"
,
"a97"
,
"a98"
,
"a99"
,
"a100"
,
"a101"
,
"a102"
,
"a103"
,
"a104"
,
"a105"
,
"a106"
,
"a107"
,
"a108"
,
"a109"
,
"a110"
,
"a111"
,
"a112"
,
"a113"
,
"a114"
,
"a115"
,
"a116"
,
"a117"
,
"a118"
,
"a119"
,
"a120"
,
"a121"
,
"a122"
,
"a123"
,
"a124"
,
"a125"
,
"a126"
,
"a127"
,
"a128"
,
"a129"
,
"a130"
,
"a131"
,
"a132"
,
"a133"
,
"a134"
,
"a135"
,
"a136"
,
"a137"
,
"a138"
,
"a139"
,
"a140"
,
"a141"
,
"a142"
,
"a143"
,
"a144"
,
"a145"
,
"a146"
,
"a147"
,
"a148"
,
"a149"
,
"a150"
,
"a151"
,
"a152"
,
"a153"
,
"a154"
,
"a155"
,
"a156"
,
"a157"
,
"a158"
,
"a159"
,
"a160"
,
"a161"
,
"a162"
,
"a163"
,
"a164"
,
"a165"
,
"a166"
,
"a167"
,
"a168"
,
"a169"
,
"a170"
,
"a171"
,
"a172"
,
"a173"
,
"a174"
,
"a175"
,
"a176"
,
"a177"
,
"a178"
,
"a179"
,
"a180"
,
"a181"
,
"a182"
,
"a183"
,
"a184"
,
"a185"
,
"a186"
,
"a187"
,
"a188"
,
"a189"
,
"a190"
,
"a191"
,
"a192"
,
"a193"
,
"a194"
,
"a195"
,
"a196"
,
"a197"
,
"a198"
,
"a199"
,
"a200"
,
"a201"
,
"a202"
,
"a203"
,
"a204"
,
"a205"
,
"a206"
,
"a207"
,
"a208"
,
"a209"
,
"a210"
,
"a211"
,
"a212"
,
"a213"
,
"a214"
,
"a215"
,
"a216"
,
"a217"
,
"a218"
,
"a219"
,
"a220"
,
"a221"
,
"a222"
,
"a223"
,
"a224"
,
"a225"
,
"a226"
,
"a227"
,
"a228"
,
"a229"
,
"a230"
,
"a231"
,
"a232"
,
"a233"
,
"a234"
,
"a235"
,
"a236"
,
"a237"
,
"a238"
,
"a239"
,
"a240"
,
"a241"
,
"a242"
,
"a243"
,
"a244"
,
"a245"
,
"a246"
,
"a247"
,
"a248"
,
"a249"
,
"a250"
,
"a251"
,
"a252"
,
"a253"
,
"a254"
,
"a255"
,
"s16"
,
"s17"
,
"s18"
,
"s19"
,
"s20"
,
"s21"
,
"s22"
,
"s23"
,
"s86"
,
// s86 as tmp
"v64"
,
"v65"
,
"v66"
,
"v67"
,
"v68"
,
"v69"
,
"v70"
,
"v71"
,
"v72"
,
"v73"
,
"v74"
,
"v75"
,
"v76"
,
"v77"
,
"v78"
,
"v79"
,
"v80"
,
"v81"
,
"v82"
,
"v83"
,
"v84"
,
"v85"
,
"v86"
,
"v87"
,
"v88"
,
"v89"
,
"v90"
,
"v91"
,
"v92"
,
"v93"
,
"v94"
,
"v95"
,
"v96"
,
"v97"
,
"v98"
,
"v99"
,
"v100"
,
"v101"
,
"v102"
,
"v103"
,
"v104"
,
"v105"
,
"v106"
,
"v107"
,
"v108"
,
"v109"
,
"v110"
,
"v111"
,
"v112"
,
"v113"
,
"v114"
,
"v115"
,
"v116"
,
"v117"
,
"v118"
,
"v119"
,
"v120"
,
"v121"
,
"v122"
,
"v123"
,
"v124"
,
"v125"
,
"v126"
,
"v127"
);
// clang-format on
#pragma clang diagnostic pop
// return local scratch
auto
c
=
MakeCBlockTile
();
for
(
auto
i
=
0
;
i
<
16
;
i
++
)
{
c
.
get_thread_buffer
()[
4
*
i
+
0
]
=
v_acc
[
i
].
x
;
c
.
get_thread_buffer
()[
4
*
i
+
1
]
=
v_acc
[
i
].
y
;
c
.
get_thread_buffer
()[
4
*
i
+
2
]
=
v_acc
[
i
].
z
;
c
.
get_thread_buffer
()[
4
*
i
+
3
]
=
v_acc
[
i
].
w
;
}
return
c
;
}
};
struct
Flatmm_32x512x128_1x4x1_16x16x32_FP16
:
public
Flatmm_32x512x128_1x4x1_16x16x32_Base
{
using
ADataType
=
fp16_t
;
using
BDataType
=
fp16_t
;
// TODO: need paired with tile_window_linear!
// TODO: need call init_raw() before call this function!
template
<
typename
ARes
,
typename
ACoords
,
typename
BRes
,
typename
BCoords
>
CK_TILE_DEVICE
auto
operator
()(
const
ARes
&
res_a
,
const
ACoords
&
cached_coords_a
,
const
BRes
&
res_b
,
const
BCoords
&
cached_coords_b
,
CK_TILE_LDS_ADDR
void
*
smem
,
index_t
k
,
index_t
tile_offset_a
,
// for each tile, the offset to move for each unroll
index_t
tile_offset_b
)
// for each tile, the offset to move for each unroll
{
static_assert
(
ACoords
::
size
()
==
Block_M
*
Block_K
/
BlockSize
/
2
/*2x per dword*/
);
// 8
static_assert
(
BCoords
::
size
()
==
Repeat_N
);
auto
a_sst
=
make_tile_window
(
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
CK_TILE_LDS_ADDR
ADataType
*>
(
smem
),
MakeLdsStoreDesc_A
()),
MakeLdsStoreDesc_A
().
get_lengths
(),
{
0
,
0
,
0
});
auto
a_sld
=
[
&
]()
{
constexpr
auto
a_warp_enc_
=
GetGemm_AWarpEnc
();
constexpr
auto
a_outer_dstr_enc
=
tile_distribution_encoding
<
sequence
<
WarpPerBlock_N
>
,
tuple
<
sequence
<
Repeat_M
,
WarpPerBlock_M
>
,
sequence
<
Repeat_K
>>
,
tuple
<
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
a_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
a_outer_dstr_enc
,
a_warp_enc_
);
return
make_tile_window_linear
(
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
CK_TILE_LDS_ADDR
ADataType
*>
(
smem
),
MakeLdsLoadDesc_A
()),
MakeLdsLoadDesc_A
().
get_lengths
(),
{
0
,
0
},
make_static_tile_distribution
(
a_block_dstr_encode
));
}();
const
index_t
tile_offset_a_bytes
=
tile_offset_a
*
sizeof
(
ADataType
);
const
index_t
tile_offset_b_bytes
=
tile_offset_b
*
sizeof
(
BDataType
);
const
auto
[
m0_init_value
,
size_per_issue
]
=
get_async_store_smem_info
(
a_sst
);
constexpr
auto
smem_buf_size
=
MakeLdsLoadDesc_A
().
get_element_space_size
()
*
sizeof
(
ADataType
);
static_assert
(
a_sld
.
get_num_of_access
()
==
8
);
constexpr
auto
sld_os
=
generate_tuple
(
[
&
](
auto
i_access
)
{
return
number
<
a_sld
.
get_bottom_linear_offset
(
i_access
)
*
sizeof
(
ADataType
)
>
{};
},
number
<
a_sld
.
get_num_of_access
()
>
{});
index_t
loop_cnt
=
k
/
Block_K
;
// this is the acc thread buffer
fp32x4_t
v_acc
[
16
]{
.0
f
};
// B nr->kr
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Winline-asm"
// clang-format off
asm
volatile
(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
#include "uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc"
#undef CK_TILE_FLATMM_UK_MFMA
:
[
s_loop_cnt
]
"+s"
(
loop_cnt
),
[
v_acc_0
]
"+v"
(
v_acc
[
0
]),
[
v_acc_1
]
"+v"
(
v_acc
[
1
]),
[
v_acc_2
]
"+v"
(
v_acc
[
2
]),
[
v_acc_3
]
"+v"
(
v_acc
[
3
]),
[
v_acc_4
]
"+v"
(
v_acc
[
4
]),
[
v_acc_5
]
"+v"
(
v_acc
[
5
]),
[
v_acc_6
]
"+v"
(
v_acc
[
6
]),
[
v_acc_7
]
"+v"
(
v_acc
[
7
]),
[
v_acc_8
]
"+v"
(
v_acc
[
8
]),
[
v_acc_9
]
"+v"
(
v_acc
[
9
]),
[
v_acc_10
]
"+v"
(
v_acc
[
10
]),
[
v_acc_11
]
"+v"
(
v_acc
[
11
]),
[
v_acc_12
]
"+v"
(
v_acc
[
12
]),
[
v_acc_13
]
"+v"
(
v_acc
[
13
]),
[
v_acc_14
]
"+v"
(
v_acc
[
14
]),
[
v_acc_15
]
"+v"
(
v_acc
[
15
]),
[
s_mem_
]
"+r"
(
smem
)
:
[
s_res_a0
]
"s"
(
res_a
[
0
]),
[
s_res_a1
]
"s"
(
res_a
[
1
]),
[
s_res_a2
]
"s"
(
res_a
[
2
]),
[
s_res_a3
]
"s"
(
res_a
[
3
]),
[
s_res_b0
]
"s"
(
res_b
[
0
]),
[
s_res_b1
]
"s"
(
res_b
[
1
]),
[
s_res_b2
]
"s"
(
res_b
[
2
]),
[
s_res_b3
]
"s"
(
res_b
[
3
]),
[
v_os_a0
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_a
[
number
<
0
>
{}]
*
sizeof
(
ADataType
))),
[
v_os_a1
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_a
[
number
<
1
>
{}]
*
sizeof
(
ADataType
))),
[
v_os_a2
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_a
[
number
<
2
>
{}]
*
sizeof
(
ADataType
))),
[
v_os_a3
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_a
[
number
<
3
>
{}]
*
sizeof
(
ADataType
))),
[
v_os_a4
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_a
[
number
<
4
>
{}]
*
sizeof
(
ADataType
))),
[
v_os_a5
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_a
[
number
<
5
>
{}]
*
sizeof
(
ADataType
))),
[
v_os_a6
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_a
[
number
<
6
>
{}]
*
sizeof
(
ADataType
))),
[
v_os_a7
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_a
[
number
<
7
>
{}]
*
sizeof
(
ADataType
))),
[
v_os_b0
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
0
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b1
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
1
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b2
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
2
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b3
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
3
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b4
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
4
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b5
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
5
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b6
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
6
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b7
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
7
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_slda
]
"v"
(
static_cast
<
index_t
>
(
a_sld
.
cached_coords_
[
number
<
0
>
{}].
get_offset
()
*
sizeof
(
ADataType
))),
[
s_m0_init
]
"s"
(
m0_init_value
),
[
s_size_per_issue
]
"s"
(
size_per_issue
),
[
smem_sz
]
"n"
(
smem_buf_size
),
//(smem_buf_size),
[
sld_os_0
]
"n"
(
sld_os
[
number
<
0
>
{}].
value
),
[
sld_os_1
]
"n"
(
sld_os
[
number
<
1
>
{}].
value
),
[
sld_os_2
]
"n"
(
sld_os
[
number
<
2
>
{}].
value
),
[
sld_os_3
]
"n"
(
sld_os
[
number
<
3
>
{}].
value
),
[
sld_os_4
]
"n"
(
sld_os
[
number
<
4
>
{}].
value
),
[
sld_os_5
]
"n"
(
sld_os
[
number
<
5
>
{}].
value
),
[
sld_os_6
]
"n"
(
sld_os
[
number
<
6
>
{}].
value
),
[
sld_os_7
]
"n"
(
sld_os
[
number
<
7
>
{}].
value
),
[
s_tile_os_a
]
"s"
(
tile_offset_a_bytes
),
[
s_tile_os_b
]
"s"
(
tile_offset_b_bytes
)
:
"memory"
,
"a0"
,
"a1"
,
"a2"
,
"a3"
,
"a4"
,
"a5"
,
"a6"
,
"a7"
,
"a8"
,
"a9"
,
"a10"
,
"a11"
,
"a12"
,
"a13"
,
"a14"
,
"a15"
,
"a16"
,
"a17"
,
"a18"
,
"a19"
,
"a20"
,
"a21"
,
"a22"
,
"a23"
,
"a24"
,
"a25"
,
"a26"
,
"a27"
,
"a28"
,
"a29"
,
"a30"
,
"a31"
,
"a32"
,
"a33"
,
"a34"
,
"a35"
,
"a36"
,
"a37"
,
"a38"
,
"a39"
,
"a40"
,
"a41"
,
"a42"
,
"a43"
,
"a44"
,
"a45"
,
"a46"
,
"a47"
,
"a48"
,
"a49"
,
"a50"
,
"a51"
,
"a52"
,
"a53"
,
"a54"
,
"a55"
,
"a56"
,
"a57"
,
"a58"
,
"a59"
,
"a60"
,
"a61"
,
"a62"
,
"a63"
,
"a64"
,
"a65"
,
"a66"
,
"a67"
,
"a68"
,
"a69"
,
"a70"
,
"a71"
,
"a72"
,
"a73"
,
"a74"
,
"a75"
,
"a76"
,
"a77"
,
"a78"
,
"a79"
,
"a80"
,
"a81"
,
"a82"
,
"a83"
,
"a84"
,
"a85"
,
"a86"
,
"a87"
,
"a88"
,
"a89"
,
"a90"
,
"a91"
,
"a92"
,
"a93"
,
"a94"
,
"a95"
,
"a96"
,
"a97"
,
"a98"
,
"a99"
,
"a100"
,
"a101"
,
"a102"
,
"a103"
,
"a104"
,
"a105"
,
"a106"
,
"a107"
,
"a108"
,
"a109"
,
"a110"
,
"a111"
,
"a112"
,
"a113"
,
"a114"
,
"a115"
,
"a116"
,
"a117"
,
"a118"
,
"a119"
,
"a120"
,
"a121"
,
"a122"
,
"a123"
,
"a124"
,
"a125"
,
"a126"
,
"a127"
,
"a128"
,
"a129"
,
"a130"
,
"a131"
,
"a132"
,
"a133"
,
"a134"
,
"a135"
,
"a136"
,
"a137"
,
"a138"
,
"a139"
,
"a140"
,
"a141"
,
"a142"
,
"a143"
,
"a144"
,
"a145"
,
"a146"
,
"a147"
,
"a148"
,
"a149"
,
"a150"
,
"a151"
,
"a152"
,
"a153"
,
"a154"
,
"a155"
,
"a156"
,
"a157"
,
"a158"
,
"a159"
,
"a160"
,
"a161"
,
"a162"
,
"a163"
,
"a164"
,
"a165"
,
"a166"
,
"a167"
,
"a168"
,
"a169"
,
"a170"
,
"a171"
,
"a172"
,
"a173"
,
"a174"
,
"a175"
,
"a176"
,
"a177"
,
"a178"
,
"a179"
,
"a180"
,
"a181"
,
"a182"
,
"a183"
,
"a184"
,
"a185"
,
"a186"
,
"a187"
,
"a188"
,
"a189"
,
"a190"
,
"a191"
,
"a192"
,
"a193"
,
"a194"
,
"a195"
,
"a196"
,
"a197"
,
"a198"
,
"a199"
,
"a200"
,
"a201"
,
"a202"
,
"a203"
,
"a204"
,
"a205"
,
"a206"
,
"a207"
,
"a208"
,
"a209"
,
"a210"
,
"a211"
,
"a212"
,
"a213"
,
"a214"
,
"a215"
,
"a216"
,
"a217"
,
"a218"
,
"a219"
,
"a220"
,
"a221"
,
"a222"
,
"a223"
,
"a224"
,
"a225"
,
"a226"
,
"a227"
,
"a228"
,
"a229"
,
"a230"
,
"a231"
,
"a232"
,
"a233"
,
"a234"
,
"a235"
,
"a236"
,
"a237"
,
"a238"
,
"a239"
,
"a240"
,
"a241"
,
"a242"
,
"a243"
,
"a244"
,
"a245"
,
"a246"
,
"a247"
,
"a248"
,
"a249"
,
"a250"
,
"a251"
,
"a252"
,
"a253"
,
"a254"
,
"a255"
,
"s16"
,
"s17"
,
"s18"
,
"s19"
,
"s20"
,
"s21"
,
"s22"
,
"s23"
,
"s86"
,
// s86 as tmp
"v64"
,
"v65"
,
"v66"
,
"v67"
,
"v68"
,
"v69"
,
"v70"
,
"v71"
,
"v72"
,
"v73"
,
"v74"
,
"v75"
,
"v76"
,
"v77"
,
"v78"
,
"v79"
,
"v80"
,
"v81"
,
"v82"
,
"v83"
,
"v84"
,
"v85"
,
"v86"
,
"v87"
,
"v88"
,
"v89"
,
"v90"
,
"v91"
,
"v92"
,
"v93"
,
"v94"
,
"v95"
,
"v96"
,
"v97"
,
"v98"
,
"v99"
,
"v100"
,
"v101"
,
"v102"
,
"v103"
,
"v104"
,
"v105"
,
"v106"
,
"v107"
,
"v108"
,
"v109"
,
"v110"
,
"v111"
,
"v112"
,
"v113"
,
"v114"
,
"v115"
,
"v116"
,
"v117"
,
"v118"
,
"v119"
,
"v120"
,
"v121"
,
"v122"
,
"v123"
,
"v124"
,
"v125"
,
"v126"
,
"v127"
);
// clang-format on
#pragma clang diagnostic pop
// return local scratch
auto
c
=
MakeCBlockTile
();
for
(
auto
i
=
0
;
i
<
16
;
i
++
)
{
c
.
get_thread_buffer
()[
4
*
i
+
0
]
=
v_acc
[
i
].
x
;
c
.
get_thread_buffer
()[
4
*
i
+
1
]
=
v_acc
[
i
].
y
;
c
.
get_thread_buffer
()[
4
*
i
+
2
]
=
v_acc
[
i
].
z
;
c
.
get_thread_buffer
()[
4
*
i
+
3
]
=
v_acc
[
i
].
w
;
}
return
c
;
}
};
}
// namespace ck_tile
include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp
0 → 100644
View file @
4525c5d7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
namespace
ck_tile
{
// "S"tream update output along "N"
// A in smem, B load from global
// require 4 wave, occupancy=1c
struct
FlatmmSn_32x128x512_1x4x1_16x16x32_Base
{
static
constexpr
index_t
Block_M
=
32
;
static
constexpr
index_t
Block_N
=
128
;
static
constexpr
index_t
Block_K
=
512
;
static
constexpr
index_t
WarpPerBlock_M
=
1
;
static
constexpr
index_t
WarpPerBlock_N
=
4
;
static
constexpr
index_t
WarpPerBlock_K
=
1
;
static
constexpr
index_t
Warp_M
=
16
;
static
constexpr
index_t
Warp_N
=
16
;
static
constexpr
index_t
Warp_K
=
32
;
static
constexpr
index_t
BlockSize
=
256
;
// static constexpr index_t KPack = 2; // this is used to gurantee every threads can do dwordx4
// TODO: note Nr/Kr/W need consider KPack
static
constexpr
index_t
Block_W
=
Warp_N
*
Warp_K
;
// 512 element
static
constexpr
index_t
Block_Nr
=
Block_N
/
Warp_N
;
// 32 element, 4 per wave
static
constexpr
index_t
Block_Kr
=
Block_K
/
Warp_K
;
// 4
static
constexpr
index_t
Repeat_M
=
Block_M
/
(
Warp_M
*
WarpPerBlock_M
);
// 2
static
constexpr
index_t
Repeat_N
=
Block_N
/
(
Warp_N
*
WarpPerBlock_N
);
// 2
static
constexpr
index_t
Repeat_K
=
Block_K
/
(
Warp_K
*
WarpPerBlock_K
);
// 16
static
CK_TILE_DEVICE
constexpr
auto
MakeCBlockDist
()
{
constexpr
auto
c_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Repeat_M
,
WarpPerBlock_M
>
,
sequence
<
Repeat_N
,
WarpPerBlock_N
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
1
>>
,
sequence
<
2
,
1
>
,
// !! note here is different
sequence
<
0
,
0
>>
{};
using
WG
=
WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution
;
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
c_block_outer_dstr_encoding
,
typename
WG
::
CWarpDstrEncoding
{});
constexpr
auto
c_block_dstr
=
make_static_tile_distribution
(
c_block_dstr_encode
);
return
c_block_dstr
;
}
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
// y y p p p y
// reg before shfl M0(2)*N0(2)*Nl(4)*Nw(4)*Mw(16)*Nv(4)
// but order is N0*M0*Nv
// in LDS we need store as
// M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16)*Nv(4) + 4)
// y y wave-id lid/16 lid%16 v
return
2
*
2
*
4
*
4
*
(
16
*
4
+
4
)
*
sizeof
(
bf16_t
);
}
};
struct
FlatmmSn_32x128x512_1x4x1_16x16x32_BF16
:
public
FlatmmSn_32x128x512_1x4x1_16x16x32_Base
{
using
BDataType
=
bf16_t
;
using
ODataType
=
bf16_t
;
// TODO: need paired with tile_window_linear!
// TODO: need call init_raw() before call this function!
// template <typename AWindow, typename BWindow, typename OWindow, typename ScaleTensor>
template
<
typename
BRes
,
typename
BCoords
,
typename
ORes
,
typename
OCoords
,
typename
OFlags
,
typename
ScaleTensor
>
CK_TILE_DEVICE
auto
operator
()(
const
BRes
&
res_b
,
const
BCoords
&
cached_coords_b
,
const
ORes
&
res_o
,
const
OCoords
&
cached_coords_o
,
const
OFlags
&
o_flags
,
// this should be in sgpr
CK_TILE_LDS_ADDR
void
*
smem
,
index_t
n
,
// loop along n dim
const
ScaleTensor
&
scale_
,
index_t
tile_offset_b
,
// stride b is fixed to blockKr * blockW, but still can adjust
index_t
tile_offset_o
)
{
static_assert
(
BCoords
::
size
()
==
8
);
// 8
static_assert
(
OCoords
::
size
()
==
8
);
const
index_t
tile_stride_b_bytes
=
tile_offset_b
*
sizeof
(
BDataType
);
const
index_t
tile_stride_o_bytes
=
tile_offset_o
*
sizeof
(
ODataType
);
static_assert
(
ScaleTensor
::
size
()
==
2
);
float
s0
=
scale_
[
number
<
0
>
{}];
float
s1
=
scale_
[
number
<
1
>
{}];
index_t
loop_cnt
=
n
/
Block_N
;
register
float
v_c0
asm
(
"v64"
);
register
float
v_c1
asm
(
"v65"
);
register
float
v_c2
asm
(
"v66"
);
register
float
v_c3
asm
(
"v67"
);
register
float
v_c4
asm
(
"v68"
);
register
float
v_c5
asm
(
"v69"
);
register
float
v_c6
asm
(
"v70"
);
register
float
v_c7
asm
(
"v71"
);
register
float
v_c8
asm
(
"v72"
);
register
float
v_c9
asm
(
"v73"
);
register
float
v_c10
asm
(
"v74"
);
register
float
v_c11
asm
(
"v75"
);
register
float
v_c12
asm
(
"v76"
);
register
float
v_c13
asm
(
"v77"
);
register
float
v_c14
asm
(
"v78"
);
register
float
v_c15
asm
(
"v79"
);
register
float
v_c16
asm
(
"v80"
);
register
float
v_c17
asm
(
"v81"
);
register
float
v_c18
asm
(
"v82"
);
register
float
v_c19
asm
(
"v83"
);
register
float
v_c20
asm
(
"v84"
);
register
float
v_c21
asm
(
"v85"
);
register
float
v_c22
asm
(
"v86"
);
register
float
v_c23
asm
(
"v87"
);
register
float
v_c24
asm
(
"v88"
);
register
float
v_c25
asm
(
"v89"
);
register
float
v_c26
asm
(
"v90"
);
register
float
v_c27
asm
(
"v91"
);
register
float
v_c28
asm
(
"v92"
);
register
float
v_c29
asm
(
"v93"
);
register
float
v_c30
asm
(
"v94"
);
register
float
v_c31
asm
(
"v95"
);
int32_t
nan_hi
=
0x7fff0000
;
int32_t
nan_lo
=
0x00007fff
;
// in smem, the layout is M0(2)*K0(128)*M1(16)*K1(4)
// every threads need 8xK in contiguous register
// ... and every wave need the same data
int
lane_id
=
threadIdx
.
x
%
64
;
int
sld_y_os
=
(
lane_id
%
16
)
*
4
+
(
lane_id
/
16
)
*
128
;
sld_y_os
*=
2
;
// y y p p p y
// reg before shfl M0(2)*N0(2)*Nl(4)*Nw(4)*Mw(16)*Nv(4)
// but order is N0*M0*Nv
// in LDS we need store as
// M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16)*Nv(4) + 4)
// y y wave-id lid/16 lid%16 v
// sst(v3) = (v0/16*34 + v0%16 * 2 + wid*136) * 4
int
sfl_sst
=
(
threadIdx
.
x
%
16
*
4
)
+
(
threadIdx
.
x
/
16
)
*
(
64
+
4
);
sfl_sst
*=
2
;
// from LDS we need load as
// M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16) * Nv(4) + 4)
// ( 2 issue) (rem 32-lane) (4 wave*4issue) 2lane*1ussue(pk2)
// sld(v4) = v0/2 *34*4 + v0 % 2 *4 + wid*2 *4
int
sfl_sld
=
(
lane_id
%
2
)
*
2
+
(
lane_id
/
2
)
*
(
64
+
4
)
+
(
threadIdx
.
x
/
64
)
*
4
;
sfl_sld
*=
2
;
// B nr->kr
// clang-format off
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Winline-asm"
asm
volatile
(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
#include "uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc"
#undef CK_TILE_FLATMM_UK_MFMA
:
[
smem_
]
"+r"
(
smem
),
[
s_loop_cnt
]
"+s"
(
loop_cnt
),
[
c0
]
"+v"
(
v_c0
),
[
c1
]
"+v"
(
v_c1
),
[
c2
]
"+v"
(
v_c2
),
[
c3
]
"+v"
(
v_c3
),
[
c4
]
"+v"
(
v_c4
),
[
c5
]
"+v"
(
v_c5
),
[
c6
]
"+v"
(
v_c6
),
[
c7
]
"+v"
(
v_c7
),
[
c8
]
"+v"
(
v_c8
),
[
c9
]
"+v"
(
v_c9
),
[
c10
]
"+v"
(
v_c10
),
[
c11
]
"+v"
(
v_c11
),
[
c12
]
"+v"
(
v_c12
),
[
c13
]
"+v"
(
v_c13
),
[
c14
]
"+v"
(
v_c14
),
[
c15
]
"+v"
(
v_c15
),
[
c16
]
"+v"
(
v_c16
),
[
c17
]
"+v"
(
v_c17
),
[
c18
]
"+v"
(
v_c18
),
[
c19
]
"+v"
(
v_c19
),
[
c20
]
"+v"
(
v_c20
),
[
c21
]
"+v"
(
v_c21
),
[
c22
]
"+v"
(
v_c22
),
[
c23
]
"+v"
(
v_c23
),
[
c24
]
"+v"
(
v_c24
),
[
c25
]
"+v"
(
v_c25
),
[
c26
]
"+v"
(
v_c26
),
[
c27
]
"+v"
(
v_c27
),
[
c28
]
"+v"
(
v_c28
),
[
c29
]
"+v"
(
v_c29
),
[
c30
]
"+v"
(
v_c30
),
[
c31
]
"+v"
(
v_c31
)
:
[
sld_a_base
]
"n"
(
0
),
[
shfl_base
]
"n"
(
0
),
[
v_sld_y_os
]
"v"
(
sld_y_os
),
[
v_sfl_sld
]
"v"
(
sfl_sld
),
[
v_sfl_sst
]
"v"
(
sfl_sst
),
[
s_res_o0
]
"s"
(
res_o
[
0
]),
[
s_res_o1
]
"s"
(
res_o
[
1
]),
//[s_res_o2]"s"(res_o[2]),
//[s_res_o3]"s"(res_o[3]),
[
s_res_b0
]
"s"
(
res_b
[
0
]),
[
s_res_b1
]
"s"
(
res_b
[
1
]),
[
s_res_b2
]
"s"
(
res_b
[
2
]),
[
s_res_b3
]
"s"
(
res_b
[
3
]),
[
v_os_o0
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
0
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o1
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
1
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o2
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
2
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o3
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
3
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o4
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
4
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o5
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
5
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o6
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
6
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o7
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
7
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_b0
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
0
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b1
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
1
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b2
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
2
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b3
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
3
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b4
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
4
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b5
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
5
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b6
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
6
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b7
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
7
>
{}]
*
sizeof
(
BDataType
))),
[
s_tile_os_o
]
"s"
(
tile_stride_o_bytes
),
[
s_tile_os_b
]
"s"
(
tile_stride_b_bytes
),
[
scale_0
]
"v"
(
s0
),
[
scale_1
]
"v"
(
s1
),
[
v_nan_lo
]
"v"
(
nan_lo
),
[
v_nan_hi
]
"v"
(
nan_hi
),
[
s_execflag_0
]
"s"
(
o_flags
[
number
<
0
>
{}]),
[
s_execflag_1
]
"s"
(
o_flags
[
number
<
1
>
{}]),
[
s_execflag_2
]
"s"
(
o_flags
[
number
<
2
>
{}]),
[
s_execflag_3
]
"s"
(
o_flags
[
number
<
3
>
{}]),
[
s_execflag_4
]
"s"
(
o_flags
[
number
<
4
>
{}]),
[
s_execflag_5
]
"s"
(
o_flags
[
number
<
5
>
{}]),
[
s_execflag_6
]
"s"
(
o_flags
[
number
<
6
>
{}]),
[
s_execflag_7
]
"s"
(
o_flags
[
number
<
7
>
{}])
:
"memory"
,
"a0"
,
"a1"
,
"a2"
,
"a3"
,
"a4"
,
"a5"
,
"a6"
,
"a7"
,
"a8"
,
"a9"
,
"a10"
,
"a11"
,
"a12"
,
"a13"
,
"a14"
,
"a15"
,
"a16"
,
"a17"
,
"a18"
,
"a19"
,
"a20"
,
"a21"
,
"a22"
,
"a23"
,
"a24"
,
"a25"
,
"a26"
,
"a27"
,
"a28"
,
"a29"
,
"a30"
,
"a31"
,
"a32"
,
"a33"
,
"a34"
,
"a35"
,
"a36"
,
"a37"
,
"a38"
,
"a39"
,
"a40"
,
"a41"
,
"a42"
,
"a43"
,
"a44"
,
"a45"
,
"a46"
,
"a47"
,
"a48"
,
"a49"
,
"a50"
,
"a51"
,
"a52"
,
"a53"
,
"a54"
,
"a55"
,
"a56"
,
"a57"
,
"a58"
,
"a59"
,
"a60"
,
"a61"
,
"a62"
,
"a63"
,
"a64"
,
"a65"
,
"a66"
,
"a67"
,
"a68"
,
"a69"
,
"a70"
,
"a71"
,
"a72"
,
"a73"
,
"a74"
,
"a75"
,
"a76"
,
"a77"
,
"a78"
,
"a79"
,
"a80"
,
"a81"
,
"a82"
,
"a83"
,
"a84"
,
"a85"
,
"a86"
,
"a87"
,
"a88"
,
"a89"
,
"a90"
,
"a91"
,
"a92"
,
"a93"
,
"a94"
,
"a95"
,
"a96"
,
"a97"
,
"a98"
,
"a99"
,
"a100"
,
"a101"
,
"a102"
,
"a103"
,
"a104"
,
"a105"
,
"a106"
,
"a107"
,
"a108"
,
"a109"
,
"a110"
,
"a111"
,
"a112"
,
"a113"
,
"a114"
,
"a115"
,
"a116"
,
"a117"
,
"a118"
,
"a119"
,
"a120"
,
"a121"
,
"a122"
,
"a123"
,
"a124"
,
"a125"
,
"a126"
,
"a127"
,
"a128"
,
"a129"
,
"a130"
,
"a131"
,
"a132"
,
"a133"
,
"a134"
,
"a135"
,
"a136"
,
"a137"
,
"a138"
,
"a139"
,
"a140"
,
"a141"
,
"a142"
,
"a143"
,
"a144"
,
"a145"
,
"a146"
,
"a147"
,
"a148"
,
"a149"
,
"a150"
,
"a151"
,
"a152"
,
"a153"
,
"a154"
,
"a155"
,
"a156"
,
"a157"
,
"a158"
,
"a159"
,
"a160"
,
"a161"
,
"a162"
,
"a163"
,
"a164"
,
"a165"
,
"a166"
,
"a167"
,
"a168"
,
"a169"
,
"a170"
,
"a171"
,
"a172"
,
"a173"
,
"a174"
,
"a175"
,
"a176"
,
"a177"
,
"a178"
,
"a179"
,
"a180"
,
"a181"
,
"a182"
,
"a183"
,
"a184"
,
"a185"
,
"a186"
,
"a187"
,
"a188"
,
"a189"
,
"a190"
,
"a191"
,
"a192"
,
"a193"
,
"a194"
,
"a195"
,
"a196"
,
"a197"
,
"a198"
,
"a199"
,
"a200"
,
"a201"
,
"a202"
,
"a203"
,
"a204"
,
"a205"
,
"a206"
,
"a207"
,
"a208"
,
"a209"
,
"a210"
,
"a211"
,
"a212"
,
"a213"
,
"a214"
,
"a215"
,
"a216"
,
"a217"
,
"a218"
,
"a219"
,
"a220"
,
"a221"
,
"a222"
,
"a223"
,
"a224"
,
"a225"
,
"a226"
,
"a227"
,
"a228"
,
"a229"
,
"a230"
,
"a231"
,
"a232"
,
"a233"
,
"a234"
,
"a235"
,
"a236"
,
"a237"
,
"a238"
,
"a239"
,
"a240"
,
"a241"
,
"a242"
,
"a243"
,
"a244"
,
"a245"
,
"a246"
,
"a247"
,
"a248"
,
"a249"
,
"a250"
,
"a251"
,
"a252"
,
"a253"
,
"a254"
,
"a255"
,
"s8"
,
"s9"
,
"s12"
,
"s13"
,
"s14"
,
"s15"
,
"s38"
,
"s39"
,
"s52"
,
"s86"
,
"s36"
,
"s37"
,
"v50"
,
"v54"
,
"v55"
,
"v64"
,
"v65"
,
"v66"
,
"v67"
,
"v68"
,
"v69"
,
"v70"
,
"v71"
,
"v72"
,
"v73"
,
"v74"
,
"v75"
,
"v76"
,
"v77"
,
"v78"
,
"v79"
,
"v80"
,
"v81"
,
"v82"
,
"v83"
,
"v84"
,
"v85"
,
"v86"
,
"v87"
,
"v88"
,
"v89"
,
"v90"
,
"v91"
,
"v92"
,
"v93"
,
"v94"
,
"v95"
,
"v128"
,
"v129"
,
"v130"
,
"v131"
,
"v132"
,
"v133"
,
"v134"
,
"v135"
,
"v136"
,
"v137"
,
"v138"
,
"v139"
,
"v140"
,
"v141"
,
"v142"
,
"v143"
,
"v144"
,
"v145"
,
"v146"
,
"v147"
,
"v148"
,
"v149"
,
"v150"
,
"v151"
,
"v152"
,
"v153"
,
"v154"
,
"v155"
,
"v156"
,
"v157"
,
"v158"
,
"v159"
,
"v160"
,
"v161"
,
"v162"
,
"v163"
,
"v164"
,
"v165"
,
"v166"
,
"v167"
,
"v168"
,
"v169"
,
"v170"
,
"v171"
,
"v172"
,
"v173"
,
"v174"
,
"v175"
,
"v176"
,
"v177"
,
"v178"
,
"v179"
,
"v180"
,
"v181"
,
"v182"
,
"v183"
,
"v184"
,
"v185"
,
"v186"
,
"v187"
,
"v188"
,
"v189"
,
"v190"
,
"v191"
,
"v192"
,
"v193"
,
"v194"
,
"v195"
,
"v196"
,
"v197"
,
"v198"
,
"v199"
,
"v200"
,
"v201"
,
"v202"
,
"v203"
,
"v204"
,
"v205"
,
"v206"
,
"v207"
,
"v208"
,
"v209"
,
"v210"
,
"v211"
,
"v212"
,
"v213"
,
"v214"
,
"v215"
,
"v216"
,
"v217"
,
"v218"
,
"v219"
,
"v220"
,
"v221"
,
"v222"
,
"v223"
,
"v224"
,
"v225"
,
"v226"
,
"v227"
,
"v228"
,
"v229"
,
"v230"
,
"v231"
,
"v232"
,
"v233"
,
"v234"
,
"v235"
,
"v236"
,
"v237"
,
"v238"
,
"v239"
,
"v240"
,
"v241"
,
"v242"
,
"v243"
,
"v244"
,
"v245"
,
"v246"
,
"v247"
,
"v248"
,
"v249"
,
"v250"
,
"v251"
,
"v252"
,
"v253"
,
"v254"
,
"v255"
);
#pragma clang diagnostic pop
// clang-format on
}
};
struct
FlatmmSn_32x128x512_1x4x1_16x16x32_FP16
:
public
FlatmmSn_32x128x512_1x4x1_16x16x32_Base
{
using
BDataType
=
bf16_t
;
using
ODataType
=
bf16_t
;
// TODO: need paired with tile_window_linear!
// TODO: need call init_raw() before call this function!
// template <typename AWindow, typename BWindow, typename OWindow, typename ScaleTensor>
template
<
typename
BRes
,
typename
BCoords
,
typename
ORes
,
typename
OCoords
,
typename
OFlags
,
typename
ScaleTensor
>
CK_TILE_DEVICE
auto
operator
()(
const
BRes
&
res_b
,
const
BCoords
&
cached_coords_b
,
const
ORes
&
res_o
,
const
OCoords
&
cached_coords_o
,
const
OFlags
&
o_flags
,
// this should be in sgpr
CK_TILE_LDS_ADDR
void
*
smem
,
index_t
n
,
// loop along n dim
const
ScaleTensor
&
scale_
,
index_t
tile_offset_b
,
// stride b is fixed to blockKr * blockW, but still can adjust
index_t
tile_offset_o
)
{
static_assert
(
BCoords
::
size
()
==
8
);
// 8
static_assert
(
OCoords
::
size
()
==
8
);
const
index_t
tile_stride_b_bytes
=
tile_offset_b
*
sizeof
(
BDataType
);
const
index_t
tile_stride_o_bytes
=
tile_offset_o
*
sizeof
(
ODataType
);
static_assert
(
ScaleTensor
::
size
()
==
2
);
float
s0
=
scale_
[
number
<
0
>
{}];
float
s1
=
scale_
[
number
<
1
>
{}];
index_t
loop_cnt
=
n
/
Block_N
;
register
float
v_c0
asm
(
"v64"
);
register
float
v_c1
asm
(
"v65"
);
register
float
v_c2
asm
(
"v66"
);
register
float
v_c3
asm
(
"v67"
);
register
float
v_c4
asm
(
"v68"
);
register
float
v_c5
asm
(
"v69"
);
register
float
v_c6
asm
(
"v70"
);
register
float
v_c7
asm
(
"v71"
);
register
float
v_c8
asm
(
"v72"
);
register
float
v_c9
asm
(
"v73"
);
register
float
v_c10
asm
(
"v74"
);
register
float
v_c11
asm
(
"v75"
);
register
float
v_c12
asm
(
"v76"
);
register
float
v_c13
asm
(
"v77"
);
register
float
v_c14
asm
(
"v78"
);
register
float
v_c15
asm
(
"v79"
);
register
float
v_c16
asm
(
"v80"
);
register
float
v_c17
asm
(
"v81"
);
register
float
v_c18
asm
(
"v82"
);
register
float
v_c19
asm
(
"v83"
);
register
float
v_c20
asm
(
"v84"
);
register
float
v_c21
asm
(
"v85"
);
register
float
v_c22
asm
(
"v86"
);
register
float
v_c23
asm
(
"v87"
);
register
float
v_c24
asm
(
"v88"
);
register
float
v_c25
asm
(
"v89"
);
register
float
v_c26
asm
(
"v90"
);
register
float
v_c27
asm
(
"v91"
);
register
float
v_c28
asm
(
"v92"
);
register
float
v_c29
asm
(
"v93"
);
register
float
v_c30
asm
(
"v94"
);
register
float
v_c31
asm
(
"v95"
);
int32_t
nan_hi
=
0x7fff0000
;
int32_t
nan_lo
=
0x00007fff
;
// in smem, the layout is M0(2)*K0(128)*M1(16)*K1(4)
// every threads need 8xK in contiguous register
// ... and every wave need the same data
int
lane_id
=
threadIdx
.
x
%
64
;
int
sld_y_os
=
(
lane_id
%
16
)
*
4
+
(
lane_id
/
16
)
*
128
;
sld_y_os
*=
2
;
// y y p p p y
// reg before shfl M0(2)*N0(2)*Nl(4)*Nw(4)*Mw(16)*Nv(4)
// but order is N0*M0*Nv
// in LDS we need store as
// M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16)*Nv(4) + 4)
// y y wave-id lid/16 lid%16 v
// sst(v3) = (v0/16*34 + v0%16 * 2 + wid*136) * 4
int
sfl_sst
=
(
threadIdx
.
x
%
16
*
4
)
+
(
threadIdx
.
x
/
16
)
*
(
64
+
4
);
sfl_sst
*=
2
;
// from LDS we need load as
// M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16) * Nv(4) + 4)
// ( 2 issue) (rem 32-lane) (4 wave*4issue) 2lane*1ussue(pk2)
// sld(v4) = v0/2 *34*4 + v0 % 2 *4 + wid*2 *4
int
sfl_sld
=
(
lane_id
%
2
)
*
2
+
(
lane_id
/
2
)
*
(
64
+
4
)
+
(
threadIdx
.
x
/
64
)
*
4
;
sfl_sld
*=
2
;
// B nr->kr
// clang-format off
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Winline-asm"
asm
volatile
(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
#include "uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc"
#undef CK_TILE_FLATMM_UK_MFMA
:
[
smem_
]
"+r"
(
smem
),
[
s_loop_cnt
]
"+s"
(
loop_cnt
),
[
c0
]
"+v"
(
v_c0
),
[
c1
]
"+v"
(
v_c1
),
[
c2
]
"+v"
(
v_c2
),
[
c3
]
"+v"
(
v_c3
),
[
c4
]
"+v"
(
v_c4
),
[
c5
]
"+v"
(
v_c5
),
[
c6
]
"+v"
(
v_c6
),
[
c7
]
"+v"
(
v_c7
),
[
c8
]
"+v"
(
v_c8
),
[
c9
]
"+v"
(
v_c9
),
[
c10
]
"+v"
(
v_c10
),
[
c11
]
"+v"
(
v_c11
),
[
c12
]
"+v"
(
v_c12
),
[
c13
]
"+v"
(
v_c13
),
[
c14
]
"+v"
(
v_c14
),
[
c15
]
"+v"
(
v_c15
),
[
c16
]
"+v"
(
v_c16
),
[
c17
]
"+v"
(
v_c17
),
[
c18
]
"+v"
(
v_c18
),
[
c19
]
"+v"
(
v_c19
),
[
c20
]
"+v"
(
v_c20
),
[
c21
]
"+v"
(
v_c21
),
[
c22
]
"+v"
(
v_c22
),
[
c23
]
"+v"
(
v_c23
),
[
c24
]
"+v"
(
v_c24
),
[
c25
]
"+v"
(
v_c25
),
[
c26
]
"+v"
(
v_c26
),
[
c27
]
"+v"
(
v_c27
),
[
c28
]
"+v"
(
v_c28
),
[
c29
]
"+v"
(
v_c29
),
[
c30
]
"+v"
(
v_c30
),
[
c31
]
"+v"
(
v_c31
)
:
[
sld_a_base
]
"n"
(
0
),
[
shfl_base
]
"n"
(
0
),
[
v_sld_y_os
]
"v"
(
sld_y_os
),
[
v_sfl_sld
]
"v"
(
sfl_sld
),
[
v_sfl_sst
]
"v"
(
sfl_sst
),
[
s_res_o0
]
"s"
(
res_o
[
0
]),
[
s_res_o1
]
"s"
(
res_o
[
1
]),
//[s_res_o2]"s"(res_o[2]),
//[s_res_o3]"s"(res_o[3]),
[
s_res_b0
]
"s"
(
res_b
[
0
]),
[
s_res_b1
]
"s"
(
res_b
[
1
]),
[
s_res_b2
]
"s"
(
res_b
[
2
]),
[
s_res_b3
]
"s"
(
res_b
[
3
]),
[
v_os_o0
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
0
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o1
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
1
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o2
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
2
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o3
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
3
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o4
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
4
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o5
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
5
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o6
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
6
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o7
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
7
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_b0
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
0
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b1
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
1
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b2
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
2
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b3
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
3
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b4
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
4
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b5
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
5
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b6
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
6
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b7
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
7
>
{}]
*
sizeof
(
BDataType
))),
[
s_tile_os_o
]
"s"
(
tile_stride_o_bytes
),
[
s_tile_os_b
]
"s"
(
tile_stride_b_bytes
),
[
scale_0
]
"v"
(
s0
),
[
scale_1
]
"v"
(
s1
),
[
v_nan_lo
]
"v"
(
nan_lo
),
[
v_nan_hi
]
"v"
(
nan_hi
),
[
s_execflag_0
]
"s"
(
o_flags
[
number
<
0
>
{}]),
[
s_execflag_1
]
"s"
(
o_flags
[
number
<
1
>
{}]),
[
s_execflag_2
]
"s"
(
o_flags
[
number
<
2
>
{}]),
[
s_execflag_3
]
"s"
(
o_flags
[
number
<
3
>
{}]),
[
s_execflag_4
]
"s"
(
o_flags
[
number
<
4
>
{}]),
[
s_execflag_5
]
"s"
(
o_flags
[
number
<
5
>
{}]),
[
s_execflag_6
]
"s"
(
o_flags
[
number
<
6
>
{}]),
[
s_execflag_7
]
"s"
(
o_flags
[
number
<
7
>
{}])
:
"memory"
,
"a0"
,
"a1"
,
"a2"
,
"a3"
,
"a4"
,
"a5"
,
"a6"
,
"a7"
,
"a8"
,
"a9"
,
"a10"
,
"a11"
,
"a12"
,
"a13"
,
"a14"
,
"a15"
,
"a16"
,
"a17"
,
"a18"
,
"a19"
,
"a20"
,
"a21"
,
"a22"
,
"a23"
,
"a24"
,
"a25"
,
"a26"
,
"a27"
,
"a28"
,
"a29"
,
"a30"
,
"a31"
,
"a32"
,
"a33"
,
"a34"
,
"a35"
,
"a36"
,
"a37"
,
"a38"
,
"a39"
,
"a40"
,
"a41"
,
"a42"
,
"a43"
,
"a44"
,
"a45"
,
"a46"
,
"a47"
,
"a48"
,
"a49"
,
"a50"
,
"a51"
,
"a52"
,
"a53"
,
"a54"
,
"a55"
,
"a56"
,
"a57"
,
"a58"
,
"a59"
,
"a60"
,
"a61"
,
"a62"
,
"a63"
,
"a64"
,
"a65"
,
"a66"
,
"a67"
,
"a68"
,
"a69"
,
"a70"
,
"a71"
,
"a72"
,
"a73"
,
"a74"
,
"a75"
,
"a76"
,
"a77"
,
"a78"
,
"a79"
,
"a80"
,
"a81"
,
"a82"
,
"a83"
,
"a84"
,
"a85"
,
"a86"
,
"a87"
,
"a88"
,
"a89"
,
"a90"
,
"a91"
,
"a92"
,
"a93"
,
"a94"
,
"a95"
,
"a96"
,
"a97"
,
"a98"
,
"a99"
,
"a100"
,
"a101"
,
"a102"
,
"a103"
,
"a104"
,
"a105"
,
"a106"
,
"a107"
,
"a108"
,
"a109"
,
"a110"
,
"a111"
,
"a112"
,
"a113"
,
"a114"
,
"a115"
,
"a116"
,
"a117"
,
"a118"
,
"a119"
,
"a120"
,
"a121"
,
"a122"
,
"a123"
,
"a124"
,
"a125"
,
"a126"
,
"a127"
,
"a128"
,
"a129"
,
"a130"
,
"a131"
,
"a132"
,
"a133"
,
"a134"
,
"a135"
,
"a136"
,
"a137"
,
"a138"
,
"a139"
,
"a140"
,
"a141"
,
"a142"
,
"a143"
,
"a144"
,
"a145"
,
"a146"
,
"a147"
,
"a148"
,
"a149"
,
"a150"
,
"a151"
,
"a152"
,
"a153"
,
"a154"
,
"a155"
,
"a156"
,
"a157"
,
"a158"
,
"a159"
,
"a160"
,
"a161"
,
"a162"
,
"a163"
,
"a164"
,
"a165"
,
"a166"
,
"a167"
,
"a168"
,
"a169"
,
"a170"
,
"a171"
,
"a172"
,
"a173"
,
"a174"
,
"a175"
,
"a176"
,
"a177"
,
"a178"
,
"a179"
,
"a180"
,
"a181"
,
"a182"
,
"a183"
,
"a184"
,
"a185"
,
"a186"
,
"a187"
,
"a188"
,
"a189"
,
"a190"
,
"a191"
,
"a192"
,
"a193"
,
"a194"
,
"a195"
,
"a196"
,
"a197"
,
"a198"
,
"a199"
,
"a200"
,
"a201"
,
"a202"
,
"a203"
,
"a204"
,
"a205"
,
"a206"
,
"a207"
,
"a208"
,
"a209"
,
"a210"
,
"a211"
,
"a212"
,
"a213"
,
"a214"
,
"a215"
,
"a216"
,
"a217"
,
"a218"
,
"a219"
,
"a220"
,
"a221"
,
"a222"
,
"a223"
,
"a224"
,
"a225"
,
"a226"
,
"a227"
,
"a228"
,
"a229"
,
"a230"
,
"a231"
,
"a232"
,
"a233"
,
"a234"
,
"a235"
,
"a236"
,
"a237"
,
"a238"
,
"a239"
,
"a240"
,
"a241"
,
"a242"
,
"a243"
,
"a244"
,
"a245"
,
"a246"
,
"a247"
,
"a248"
,
"a249"
,
"a250"
,
"a251"
,
"a252"
,
"a253"
,
"a254"
,
"a255"
,
"s8"
,
"s9"
,
"s12"
,
"s13"
,
"s14"
,
"s15"
,
"s38"
,
"s39"
,
"s52"
,
"s86"
,
"s36"
,
"s37"
,
"v50"
,
"v54"
,
"v55"
,
"v64"
,
"v65"
,
"v66"
,
"v67"
,
"v68"
,
"v69"
,
"v70"
,
"v71"
,
"v72"
,
"v73"
,
"v74"
,
"v75"
,
"v76"
,
"v77"
,
"v78"
,
"v79"
,
"v80"
,
"v81"
,
"v82"
,
"v83"
,
"v84"
,
"v85"
,
"v86"
,
"v87"
,
"v88"
,
"v89"
,
"v90"
,
"v91"
,
"v92"
,
"v93"
,
"v94"
,
"v95"
,
"v128"
,
"v129"
,
"v130"
,
"v131"
,
"v132"
,
"v133"
,
"v134"
,
"v135"
,
"v136"
,
"v137"
,
"v138"
,
"v139"
,
"v140"
,
"v141"
,
"v142"
,
"v143"
,
"v144"
,
"v145"
,
"v146"
,
"v147"
,
"v148"
,
"v149"
,
"v150"
,
"v151"
,
"v152"
,
"v153"
,
"v154"
,
"v155"
,
"v156"
,
"v157"
,
"v158"
,
"v159"
,
"v160"
,
"v161"
,
"v162"
,
"v163"
,
"v164"
,
"v165"
,
"v166"
,
"v167"
,
"v168"
,
"v169"
,
"v170"
,
"v171"
,
"v172"
,
"v173"
,
"v174"
,
"v175"
,
"v176"
,
"v177"
,
"v178"
,
"v179"
,
"v180"
,
"v181"
,
"v182"
,
"v183"
,
"v184"
,
"v185"
,
"v186"
,
"v187"
,
"v188"
,
"v189"
,
"v190"
,
"v191"
,
"v192"
,
"v193"
,
"v194"
,
"v195"
,
"v196"
,
"v197"
,
"v198"
,
"v199"
,
"v200"
,
"v201"
,
"v202"
,
"v203"
,
"v204"
,
"v205"
,
"v206"
,
"v207"
,
"v208"
,
"v209"
,
"v210"
,
"v211"
,
"v212"
,
"v213"
,
"v214"
,
"v215"
,
"v216"
,
"v217"
,
"v218"
,
"v219"
,
"v220"
,
"v221"
,
"v222"
,
"v223"
,
"v224"
,
"v225"
,
"v226"
,
"v227"
,
"v228"
,
"v229"
,
"v230"
,
"v231"
,
"v232"
,
"v233"
,
"v234"
,
"v235"
,
"v236"
,
"v237"
,
"v238"
,
"v239"
,
"v240"
,
"v241"
,
"v242"
,
"v243"
,
"v244"
,
"v245"
,
"v246"
,
"v247"
,
"v248"
,
"v249"
,
"v250"
,
"v251"
,
"v252"
,
"v253"
,
"v254"
,
"v255"
);
#pragma clang diagnostic pop
// clang-format on
}
};
}
// namespace ck_tile
include/ck_tile/ops/flatmm/block/flatmm_uk_config.hpp
0 → 100644
View file @
4525c5d7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#define CK_TILE_FLATMM_UK_MFMA_FP16 0
#define CK_TILE_FLATMM_UK_MFMA_BF16 1
#define CK_TILE_FLATMM_UK_MFMA_INT8 2
#define CK_TILE_FLATMM_UK_MFMA_FP8 3
#define CK_TILE_FLATMM_UK_MFMA_BF8 4
include/ck_tile/ops/flatmm/block/uk/README.md
0 → 100644
View file @
4525c5d7
the files under this folder should not be included directly!
\ No newline at end of file
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc
0 → 100644
View file @
4525c5d7
#ifndef CK_TILE_FLATMM_UK_MFMA
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
#endif
#if CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_BF16
# define _UK_MFMA_ "v_mfma_f32_16x16x16_bf16"
# define _UK_PK_CVT_(x0_, x1_, y_) \
" v_cmp_u_f32 s[36:37], "
x0_
", "
x0_
"
\n
"
\
" v_add3_u32 v50, "
x0_
", %[v_nan_lo], 1
\n
"
\
" v_cndmask_b32 v54, v50, %[v_nan_hi], s[36:37]
\n
"
\
" v_cmp_u_f32 s[36:37], "
x1_
", "
x1_
"
\n
"
\
" v_add3_u32 v50, "
x1_
", %[v_nan_lo], 1
\n
"
\
" v_cndmask_b32 v55, v50, %[v_nan_hi], s[36:37]
\n
"
\
" v_perm_b32 "
y_
", v55, v54, s52
\n
"
# define _UK_ATOMIC_ADD_ "global_atomic_pk_add_bf16"
#elif CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_FP16
#define _UK_MFMA_ "v_mfma_f32_16x16x16_f16"
# define _UK_PK_CVT_(x0_, x1_, y_) \
" v_cvt_f16_f32 v54, "
x0_
"
\n
"
\
" v_cvt_f16_f32 v55, "
x1_
"
\n
"
\
" v_pack_b32_f16 "
y_
", v54, v55
\n
"
# define _UK_ATOMIC_ADD_ "global_atomic_pk_add_f16"
#endif
";-------------------------------------------------------------
\n
"
" s_mov_b32 s52, 0x07060302 ; v_perm
\n
"
" s_mov_b64 s[38:39], exec ; save current exec
\n
"
" s_mov_b32 s8, %[s_res_o0]
\n
"
" s_mov_b32 s9, %[s_res_o1]
\n
"
" s_mov_b32 s12, %[s_res_b0]
\n
"
" s_mov_b32 s13, %[s_res_b1]
\n
"
" s_mov_b32 s14, %[s_res_b2]
\n
"
" s_mov_b32 s15, %[s_res_b3]
\n
"
" ds_read_b64 v[128:129], %[v_sld_y_os] offset:0 + %[sld_a_base]
\n
"
" ds_read_b64 v[130:131], %[v_sld_y_os] offset:128 + %[sld_a_base]
\n
"
" ds_read_b64 v[132:133], %[v_sld_y_os] offset:1024 + %[sld_a_base]
\n
"
" ds_read_b64 v[134:135], %[v_sld_y_os] offset:1152 + %[sld_a_base]
\n
"
" ds_read_b64 v[136:137], %[v_sld_y_os] offset:2048 + %[sld_a_base]
\n
"
" ds_read_b64 v[138:139], %[v_sld_y_os] offset:2176 + %[sld_a_base]
\n
"
" ds_read_b64 v[140:141], %[v_sld_y_os] offset:3072 + %[sld_a_base]
\n
"
" ds_read_b64 v[142:143], %[v_sld_y_os] offset:3200 + %[sld_a_base]
\n
"
" ds_read_b64 v[144:145], %[v_sld_y_os] offset:4096 + %[sld_a_base]
\n
"
" ds_read_b64 v[146:147], %[v_sld_y_os] offset:4224 + %[sld_a_base]
\n
"
" ds_read_b64 v[148:149], %[v_sld_y_os] offset:5120 + %[sld_a_base]
\n
"
" ds_read_b64 v[150:151], %[v_sld_y_os] offset:5248 + %[sld_a_base]
\n
"
" ds_read_b64 v[152:153], %[v_sld_y_os] offset:6144 + %[sld_a_base]
\n
"
" ds_read_b64 v[154:155], %[v_sld_y_os] offset:6272 + %[sld_a_base]
\n
"
" ds_read_b64 v[156:157], %[v_sld_y_os] offset:7168 + %[sld_a_base]
\n
"
" ds_read_b64 v[158:159], %[v_sld_y_os] offset:7296 + %[sld_a_base]
\n
"
" ds_read_b64 v[160:161], %[v_sld_y_os] offset:8192 + %[sld_a_base]
\n
"
" ds_read_b64 v[162:163], %[v_sld_y_os] offset:8320 + %[sld_a_base]
\n
"
" ds_read_b64 v[164:165], %[v_sld_y_os] offset:9216 + %[sld_a_base]
\n
"
" ds_read_b64 v[166:167], %[v_sld_y_os] offset:9344 + %[sld_a_base]
\n
"
" ds_read_b64 v[168:169], %[v_sld_y_os] offset:10240 + %[sld_a_base]
\n
"
" ds_read_b64 v[170:171], %[v_sld_y_os] offset:10368 + %[sld_a_base]
\n
"
" ds_read_b64 v[172:173], %[v_sld_y_os] offset:11264 + %[sld_a_base]
\n
"
" ds_read_b64 v[174:175], %[v_sld_y_os] offset:11392 + %[sld_a_base]
\n
"
" ds_read_b64 v[176:177], %[v_sld_y_os] offset:12288 + %[sld_a_base]
\n
"
" ds_read_b64 v[178:179], %[v_sld_y_os] offset:12416 + %[sld_a_base]
\n
"
" ds_read_b64 v[180:181], %[v_sld_y_os] offset:13312 + %[sld_a_base]
\n
"
" ds_read_b64 v[182:183], %[v_sld_y_os] offset:13440 + %[sld_a_base]
\n
"
" ds_read_b64 v[184:185], %[v_sld_y_os] offset:14336 + %[sld_a_base]
\n
"
" ds_read_b64 v[186:187], %[v_sld_y_os] offset:14464 + %[sld_a_base]
\n
"
" ds_read_b64 v[188:189], %[v_sld_y_os] offset:15360 + %[sld_a_base]
\n
"
" ds_read_b64 v[190:191], %[v_sld_y_os] offset:15488 + %[sld_a_base]
\n
"
" ds_read_b64 v[192:193], %[v_sld_y_os] offset:16384 + %[sld_a_base]
\n
"
" ds_read_b64 v[194:195], %[v_sld_y_os] offset:16512 + %[sld_a_base]
\n
"
" ds_read_b64 v[196:197], %[v_sld_y_os] offset:17408 + %[sld_a_base]
\n
"
" ds_read_b64 v[198:199], %[v_sld_y_os] offset:17536 + %[sld_a_base]
\n
"
" ds_read_b64 v[200:201], %[v_sld_y_os] offset:18432 + %[sld_a_base]
\n
"
" ds_read_b64 v[202:203], %[v_sld_y_os] offset:18560 + %[sld_a_base]
\n
"
" ds_read_b64 v[204:205], %[v_sld_y_os] offset:19456 + %[sld_a_base]
\n
"
" ds_read_b64 v[206:207], %[v_sld_y_os] offset:19584 + %[sld_a_base]
\n
"
" ds_read_b64 v[208:209], %[v_sld_y_os] offset:20480 + %[sld_a_base]
\n
"
" ds_read_b64 v[210:211], %[v_sld_y_os] offset:20608 + %[sld_a_base]
\n
"
" ds_read_b64 v[212:213], %[v_sld_y_os] offset:21504 + %[sld_a_base]
\n
"
" ds_read_b64 v[214:215], %[v_sld_y_os] offset:21632 + %[sld_a_base]
\n
"
" ds_read_b64 v[216:217], %[v_sld_y_os] offset:22528 + %[sld_a_base]
\n
"
" ds_read_b64 v[218:219], %[v_sld_y_os] offset:22656 + %[sld_a_base]
\n
"
" ds_read_b64 v[220:221], %[v_sld_y_os] offset:23552 + %[sld_a_base]
\n
"
" ds_read_b64 v[222:223], %[v_sld_y_os] offset:23680 + %[sld_a_base]
\n
"
" ds_read_b64 v[224:225], %[v_sld_y_os] offset:24576 + %[sld_a_base]
\n
"
" ds_read_b64 v[226:227], %[v_sld_y_os] offset:24704 + %[sld_a_base]
\n
"
" ds_read_b64 v[228:229], %[v_sld_y_os] offset:25600 + %[sld_a_base]
\n
"
" ds_read_b64 v[230:231], %[v_sld_y_os] offset:25728 + %[sld_a_base]
\n
"
" ds_read_b64 v[232:233], %[v_sld_y_os] offset:26624 + %[sld_a_base]
\n
"
" ds_read_b64 v[234:235], %[v_sld_y_os] offset:26752 + %[sld_a_base]
\n
"
" ds_read_b64 v[236:237], %[v_sld_y_os] offset:27648 + %[sld_a_base]
\n
"
" ds_read_b64 v[238:239], %[v_sld_y_os] offset:27776 + %[sld_a_base]
\n
"
" ds_read_b64 v[240:241], %[v_sld_y_os] offset:28672 + %[sld_a_base]
\n
"
" ds_read_b64 v[242:243], %[v_sld_y_os] offset:28800 + %[sld_a_base]
\n
"
" ds_read_b64 v[244:245], %[v_sld_y_os] offset:29696 + %[sld_a_base]
\n
"
" ds_read_b64 v[246:247], %[v_sld_y_os] offset:29824 + %[sld_a_base]
\n
"
" ds_read_b64 v[248:249], %[v_sld_y_os] offset:30720 + %[sld_a_base]
\n
"
" ds_read_b64 v[250:251], %[v_sld_y_os] offset:30848 + %[sld_a_base]
\n
"
" ds_read_b64 v[252:253], %[v_sld_y_os] offset:31744 + %[sld_a_base]
\n
"
" ds_read_b64 v[254:255], %[v_sld_y_os] offset:31872 + %[sld_a_base]
\n
"
" s_waitcnt 0
\n
"
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[12:15], 0 offen
\n
"
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[12:15], 0 offen offset:1024
\n
"
" buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[12:15], 0 offen offset:2048
\n
"
" buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[12:15], 0 offen offset:3072
\n
"
" buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[12:15], 0 offen
\n
"
" buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[12:15], 0 offen offset:1024
\n
"
" buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[12:15], 0 offen offset:2048
\n
"
" buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[12:15], 0 offen offset:3072
\n
"
" buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[12:15], 0 offen
\n
"
" buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[12:15], 0 offen offset:1024
\n
"
" buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[12:15], 0 offen offset:2048
\n
"
" buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[12:15], 0 offen offset:3072
\n
"
" buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[12:15], 0 offen
\n
"
" buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[12:15], 0 offen offset:1024
\n
"
" buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[12:15], 0 offen offset:2048
\n
"
" buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[12:15], 0 offen offset:3072
\n
"
" buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[12:15], 0 offen
\n
"
" buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[12:15], 0 offen offset:1024
\n
"
" buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[12:15], 0 offen offset:2048
\n
"
" buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[12:15], 0 offen offset:3072
\n
"
" buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[12:15], 0 offen
\n
"
" buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[12:15], 0 offen offset:1024
\n
"
" buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[12:15], 0 offen offset:2048
\n
"
" buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[12:15], 0 offen offset:3072
\n
"
" buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[12:15], 0 offen
\n
"
" buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[12:15], 0 offen offset:1024
\n
"
" buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[12:15], 0 offen offset:2048
\n
"
" buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[12:15], 0 offen offset:3072
\n
"
" buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[12:15], 0 offen
\n
"
" buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[12:15], 0 offen offset:1024
\n
"
" buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[12:15], 0 offen offset:2048
\n
"
" buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[12:15], 0 offen offset:3072
\n
"
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond
\n
"
" s_cselect_b32 s86, %[s_tile_os_b], 0
\n
"
" s_add_u32 s12, s86, s12
\n
"
" s_addc_u32 s13, 0, s13
\n
"
" s_waitcnt 0
\n
"
"L_start%=:
\n
"
" s_waitcnt vmcnt(32)
\n
"
" s_barrier
\n
"
_UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[0:1], v[128:129], 0
\n
"
" buffer_load_dwordx4 acc[128:131], %[v_os_b0], s[12:15], 0 offen
\n
"
_UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[2:3], v[130:131], [%[c0], %[c1], %[c2], %[c3]]
\n
"
_UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[4:5], v[132:133], [%[c0], %[c1], %[c2], %[c3]]
\n
"
_UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[6:7], v[134:135], [%[c0], %[c1], %[c2], %[c3]]
\n
"
_UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[8:9], v[136:137], [%[c0], %[c1], %[c2], %[c3]]
\n
"
" buffer_load_dwordx4 acc[132:135], %[v_os_b0], s[12:15], 0 offen offset:1024
\n
"
_UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[10:11], v[138:139], [%[c0], %[c1], %[c2], %[c3]]
\n
"
_UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[12:13], v[140:141], [%[c0], %[c1], %[c2], %[c3]]
\n
"
_UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[14:15], v[142:143], [%[c0], %[c1], %[c2], %[c3]]
\n
"
_UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[0:1], v[192:193], 0
\n
"
" buffer_load_dwordx4 acc[136:139], %[v_os_b0], s[12:15], 0 offen offset:2048
\n
"
_UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[2:3], v[194:195], [%[c4], %[c5], %[c6], %[c7]]
\n
"
_UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[4:5], v[196:197], [%[c4], %[c5], %[c6], %[c7]]
\n
"
_UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[6:7], v[198:199], [%[c4], %[c5], %[c6], %[c7]]
\n
"
_UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[8:9], v[200:201], [%[c4], %[c5], %[c6], %[c7]]
\n
"
" buffer_load_dwordx4 acc[140:143], %[v_os_b0], s[12:15], 0 offen offset:3072
\n
"
_UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[10:11], v[202:203], [%[c4], %[c5], %[c6], %[c7]]
\n
"
_UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[12:13], v[204:205], [%[c4], %[c5], %[c6], %[c7]]
\n
"
_UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[14:15], v[206:207], [%[c4], %[c5], %[c6], %[c7]]
\n
"
_UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[16:17], v[128:129], 0
\n
"
" buffer_load_dwordx4 acc[144:147], %[v_os_b1], s[12:15], 0 offen
\n
"
_UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[18:19], v[130:131], [%[c8], %[c9], %[c10], %[c11]]
\n
"
_UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[20:21], v[132:133], [%[c8], %[c9], %[c10], %[c11]]
\n
"
_UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[22:23], v[134:135], [%[c8], %[c9], %[c10], %[c11]]
\n
"
_UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[24:25], v[136:137], [%[c8], %[c9], %[c10], %[c11]]
\n
"
" buffer_load_dwordx4 acc[148:151], %[v_os_b1], s[12:15], 0 offen offset:1024
\n
"
_UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[26:27], v[138:139], [%[c8], %[c9], %[c10], %[c11]]
\n
"
_UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[28:29], v[140:141], [%[c8], %[c9], %[c10], %[c11]]
\n
"
_UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[30:31], v[142:143], [%[c8], %[c9], %[c10], %[c11]]
\n
"
_UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[16:17], v[192:193], 0
\n
"
" buffer_load_dwordx4 acc[152:155], %[v_os_b1], s[12:15], 0 offen offset:2048
\n
"
_UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[18:19], v[194:195], [%[c12], %[c13], %[c14], %[c15]]
\n
"
_UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[20:21], v[196:197], [%[c12], %[c13], %[c14], %[c15]]
\n
"
_UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[22:23], v[198:199], [%[c12], %[c13], %[c14], %[c15]]
\n
"
_UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[24:25], v[200:201], [%[c12], %[c13], %[c14], %[c15]]
\n
"
" buffer_load_dwordx4 acc[156:159], %[v_os_b1], s[12:15], 0 offen offset:3072
\n
"
_UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[26:27], v[202:203], [%[c12], %[c13], %[c14], %[c15]]
\n
"
_UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[28:29], v[204:205], [%[c12], %[c13], %[c14], %[c15]]
\n
"
_UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[30:31], v[206:207], [%[c12], %[c13], %[c14], %[c15]]
\n
"
" s_waitcnt vmcnt(32)
\n
"
_UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[32:33], v[144:145], [%[c0], %[c1], %[c2], %[c3]]
\n
"
" buffer_load_dwordx4 acc[160:163], %[v_os_b2], s[12:15], 0 offen
\n
"
_UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[34:35], v[146:147], [%[c0], %[c1], %[c2], %[c3]]
\n
"
_UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[36:37], v[148:149], [%[c0], %[c1], %[c2], %[c3]]
\n
"
_UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[38:39], v[150:151], [%[c0], %[c1], %[c2], %[c3]]
\n
"
_UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[40:41], v[152:153], [%[c0], %[c1], %[c2], %[c3]]
\n
"
" buffer_load_dwordx4 acc[164:167], %[v_os_b2], s[12:15], 0 offen offset:1024
\n
"
_UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[42:43], v[154:155], [%[c0], %[c1], %[c2], %[c3]]
\n
"
_UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[44:45], v[156:157], [%[c0], %[c1], %[c2], %[c3]]
\n
"
_UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[46:47], v[158:159], [%[c0], %[c1], %[c2], %[c3]]
\n
"
_UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[32:33], v[208:209], [%[c4], %[c5], %[c6], %[c7]]
\n
"
" buffer_load_dwordx4 acc[168:171], %[v_os_b2], s[12:15], 0 offen offset:2048
\n
"
_UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[34:35], v[210:211], [%[c4], %[c5], %[c6], %[c7]]
\n
"
_UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[36:37], v[212:213], [%[c4], %[c5], %[c6], %[c7]]
\n
"
_UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[38:39], v[214:215], [%[c4], %[c5], %[c6], %[c7]]
\n
"
_UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[40:41], v[216:217], [%[c4], %[c5], %[c6], %[c7]]
\n
"
" buffer_load_dwordx4 acc[172:175], %[v_os_b2], s[12:15], 0 offen offset:3072
\n
"
_UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[42:43], v[218:219], [%[c4], %[c5], %[c6], %[c7]]
\n
"
_UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[44:45], v[220:221], [%[c4], %[c5], %[c6], %[c7]]
\n
"
_UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[46:47], v[222:223], [%[c4], %[c5], %[c6], %[c7]]
\n
"
_UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[48:49], v[144:145], [%[c8], %[c9], %[c10], %[c11]]
\n
"
" buffer_load_dwordx4 acc[176:179], %[v_os_b3], s[12:15], 0 offen
\n
"
_UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[50:51], v[146:147], [%[c8], %[c9], %[c10], %[c11]]
\n
"
_UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[52:53], v[148:149], [%[c8], %[c9], %[c10], %[c11]]
\n
"
_UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[54:55], v[150:151], [%[c8], %[c9], %[c10], %[c11]]
\n
"
_UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[56:57], v[152:153], [%[c8], %[c9], %[c10], %[c11]]
\n
"
" buffer_load_dwordx4 acc[180:183], %[v_os_b3], s[12:15], 0 offen offset:1024
\n
"
_UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[58:59], v[154:155], [%[c8], %[c9], %[c10], %[c11]]
\n
"
_UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[60:61], v[156:157], [%[c8], %[c9], %[c10], %[c11]]
\n
"
_UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[62:63], v[158:159], [%[c8], %[c9], %[c10], %[c11]]
\n
"
_UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[48:49], v[208:209], [%[c12], %[c13], %[c14], %[c15]]
\n
"
" buffer_load_dwordx4 acc[184:187], %[v_os_b3], s[12:15], 0 offen offset:2048
\n
"
_UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[50:51], v[210:211], [%[c12], %[c13], %[c14], %[c15]]
\n
"
_UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[52:53], v[212:213], [%[c12], %[c13], %[c14], %[c15]]
\n
"
_UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[54:55], v[214:215], [%[c12], %[c13], %[c14], %[c15]]
\n
"
_UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[56:57], v[216:217], [%[c12], %[c13], %[c14], %[c15]]
\n
"
" buffer_load_dwordx4 acc[188:191], %[v_os_b3], s[12:15], 0 offen offset:3072
\n
"
_UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[58:59], v[218:219], [%[c12], %[c13], %[c14], %[c15]]
\n
"
_UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[60:61], v[220:221], [%[c12], %[c13], %[c14], %[c15]]
\n
"
_UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[62:63], v[222:223], [%[c12], %[c13], %[c14], %[c15]]
\n
"
" s_waitcnt vmcnt(32)
\n
"
_UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[64:65], v[160:161], [%[c0], %[c1], %[c2], %[c3]]
\n
"
" buffer_load_dwordx4 acc[192:195], %[v_os_b4], s[12:15], 0 offen
\n
"
_UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[66:67], v[162:163], [%[c0], %[c1], %[c2], %[c3]]
\n
"
_UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[68:69], v[164:165], [%[c0], %[c1], %[c2], %[c3]]
\n
"
_UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[70:71], v[166:167], [%[c0], %[c1], %[c2], %[c3]]
\n
"
_UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[72:73], v[168:169], [%[c0], %[c1], %[c2], %[c3]]
\n
"
" buffer_load_dwordx4 acc[196:199], %[v_os_b4], s[12:15], 0 offen offset:1024
\n
"
_UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[74:75], v[170:171], [%[c0], %[c1], %[c2], %[c3]]
\n
"
_UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[76:77], v[172:173], [%[c0], %[c1], %[c2], %[c3]]
\n
"
_UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[78:79], v[174:175], [%[c0], %[c1], %[c2], %[c3]]
\n
"
_UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[64:65], v[224:225], [%[c4], %[c5], %[c6], %[c7]]
\n
"
" buffer_load_dwordx4 acc[200:203], %[v_os_b4], s[12:15], 0 offen offset:2048
\n
"
_UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[66:67], v[226:227], [%[c4], %[c5], %[c6], %[c7]]
\n
"
_UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[68:69], v[228:229], [%[c4], %[c5], %[c6], %[c7]]
\n
"
_UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[70:71], v[230:231], [%[c4], %[c5], %[c6], %[c7]]
\n
"
_UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[72:73], v[232:233], [%[c4], %[c5], %[c6], %[c7]]
\n
"
" buffer_load_dwordx4 acc[204:207], %[v_os_b4], s[12:15], 0 offen offset:3072
\n
"
_UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[74:75], v[234:235], [%[c4], %[c5], %[c6], %[c7]]
\n
"
_UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[76:77], v[236:237], [%[c4], %[c5], %[c6], %[c7]]
\n
"
_UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[78:79], v[238:239], [%[c4], %[c5], %[c6], %[c7]]
\n
"
_UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[80:81], v[160:161], [%[c8], %[c9], %[c10], %[c11]]
\n
"
" buffer_load_dwordx4 acc[208:211], %[v_os_b5], s[12:15], 0 offen
\n
"
_UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[82:83], v[162:163], [%[c8], %[c9], %[c10], %[c11]]
\n
"
_UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[84:85], v[164:165], [%[c8], %[c9], %[c10], %[c11]]
\n
"
_UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[86:87], v[166:167], [%[c8], %[c9], %[c10], %[c11]]
\n
"
_UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[88:89], v[168:169], [%[c8], %[c9], %[c10], %[c11]]
\n
"
" buffer_load_dwordx4 acc[212:215], %[v_os_b5], s[12:15], 0 offen offset:1024
\n
"
_UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[90:91], v[170:171], [%[c8], %[c9], %[c10], %[c11]]
\n
"
_UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[92:93], v[172:173], [%[c8], %[c9], %[c10], %[c11]]
\n
"
_UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[94:95], v[174:175], [%[c8], %[c9], %[c10], %[c11]]
\n
"
_UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[80:81], v[224:225], [%[c12], %[c13], %[c14], %[c15]]
\n
"
" buffer_load_dwordx4 acc[216:219], %[v_os_b5], s[12:15], 0 offen offset:2048
\n
"
_UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[82:83], v[226:227], [%[c12], %[c13], %[c14], %[c15]]
\n
"
_UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[84:85], v[228:229], [%[c12], %[c13], %[c14], %[c15]]
\n
"
_UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[86:87], v[230:231], [%[c12], %[c13], %[c14], %[c15]]
\n
"
_UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[88:89], v[232:233], [%[c12], %[c13], %[c14], %[c15]]
\n
"
" buffer_load_dwordx4 acc[220:223], %[v_os_b5], s[12:15], 0 offen offset:3072
\n
"
_UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[90:91], v[234:235], [%[c12], %[c13], %[c14], %[c15]]
\n
"
_UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[92:93], v[236:237], [%[c12], %[c13], %[c14], %[c15]]
\n
"
_UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[94:95], v[238:239], [%[c12], %[c13], %[c14], %[c15]]
\n
"
" s_waitcnt vmcnt(32)
\n
"
_UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[96:97], v[176:177], [%[c0], %[c1], %[c2], %[c3]]
\n
"
" buffer_load_dwordx4 acc[224:227], %[v_os_b6], s[12:15], 0 offen
\n
"
_UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[98:99], v[178:179], [%[c0], %[c1], %[c2], %[c3]]
\n
"
_UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[100:101], v[180:181], [%[c0], %[c1], %[c2], %[c3]]
\n
"
_UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[102:103], v[182:183], [%[c0], %[c1], %[c2], %[c3]]
\n
"
_UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[104:105], v[184:185], [%[c0], %[c1], %[c2], %[c3]]
\n
"
" buffer_load_dwordx4 acc[228:231], %[v_os_b6], s[12:15], 0 offen offset:1024
\n
"
_UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[106:107], v[186:187], [%[c0], %[c1], %[c2], %[c3]]
\n
"
_UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[108:109], v[188:189], [%[c0], %[c1], %[c2], %[c3]]
\n
"
_UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[110:111], v[190:191], [%[c0], %[c1], %[c2], %[c3]]
\n
"
_UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[96:97], v[240:241], [%[c4], %[c5], %[c6], %[c7]]
\n
"
" buffer_load_dwordx4 acc[232:235], %[v_os_b6], s[12:15], 0 offen offset:2048
\n
"
_UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[98:99], v[242:243], [%[c4], %[c5], %[c6], %[c7]]
\n
"
_UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[100:101], v[244:245], [%[c4], %[c5], %[c6], %[c7]]
\n
"
_UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[102:103], v[246:247], [%[c4], %[c5], %[c6], %[c7]]
\n
"
_UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[104:105], v[248:249], [%[c4], %[c5], %[c6], %[c7]]
\n
"
" buffer_load_dwordx4 acc[236:239], %[v_os_b6], s[12:15], 0 offen offset:3072
\n
"
_UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[106:107], v[250:251], [%[c4], %[c5], %[c6], %[c7]]
\n
"
_UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[108:109], v[252:253], [%[c4], %[c5], %[c6], %[c7]]
\n
"
_UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[110:111], v[254:255], [%[c4], %[c5], %[c6], %[c7]]
\n
"
_UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[112:113], v[176:177], [%[c8], %[c9], %[c10], %[c11]]
\n
"
" buffer_load_dwordx4 acc[240:243], %[v_os_b7], s[12:15], 0 offen
\n
"
_UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[114:115], v[178:179], [%[c8], %[c9], %[c10], %[c11]]
\n
"
_UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[116:117], v[180:181], [%[c8], %[c9], %[c10], %[c11]]
\n
"
_UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[118:119], v[182:183], [%[c8], %[c9], %[c10], %[c11]]
\n
"
_UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[120:121], v[184:185], [%[c8], %[c9], %[c10], %[c11]]
\n
"
" buffer_load_dwordx4 acc[244:247], %[v_os_b7], s[12:15], 0 offen offset:1024
\n
"
_UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[122:123], v[186:187], [%[c8], %[c9], %[c10], %[c11]]
\n
"
_UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[124:125], v[188:189], [%[c8], %[c9], %[c10], %[c11]]
\n
"
_UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[126:127], v[190:191], [%[c8], %[c9], %[c10], %[c11]]
\n
"
_UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[112:113], v[240:241], [%[c12], %[c13], %[c14], %[c15]]
\n
"
" buffer_load_dwordx4 acc[248:251], %[v_os_b7], s[12:15], 0 offen offset:2048
\n
"
_UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[114:115], v[242:243], [%[c12], %[c13], %[c14], %[c15]]
\n
"
_UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[116:117], v[244:245], [%[c12], %[c13], %[c14], %[c15]]
\n
"
_UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[118:119], v[246:247], [%[c12], %[c13], %[c14], %[c15]]
\n
"
_UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[120:121], v[248:249], [%[c12], %[c13], %[c14], %[c15]]
\n
"
" buffer_load_dwordx4 acc[252:255], %[v_os_b7], s[12:15], 0 offen offset:3072
\n
"
_UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[122:123], v[250:251], [%[c12], %[c13], %[c14], %[c15]]
\n
"
_UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[124:125], v[252:253], [%[c12], %[c13], %[c14], %[c15]]
\n
"
_UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[126:127], v[254:255], [%[c12], %[c13], %[c14], %[c15]]
\n
"
" v_mul_f32 %[c0], %[scale_0], %[c0]
\n
"
" v_mul_f32 %[c1], %[scale_0], %[c1]
\n
"
" v_mul_f32 %[c2], %[scale_0], %[c2]
\n
"
" v_mul_f32 %[c3], %[scale_0], %[c3]
\n
"
" v_mul_f32 %[c4], %[scale_1], %[c4]
\n
"
" v_mul_f32 %[c5], %[scale_1], %[c5]
\n
"
" v_mul_f32 %[c6], %[scale_1], %[c6]
\n
"
" v_mul_f32 %[c7], %[scale_1], %[c7]
\n
"
" v_mul_f32 %[c8], %[scale_0], %[c8]
\n
"
" v_mul_f32 %[c9], %[scale_0], %[c9]
\n
"
" v_mul_f32 %[c10], %[scale_0], %[c10]
\n
"
" v_mul_f32 %[c11], %[scale_0], %[c11]
\n
"
" v_mul_f32 %[c12], %[scale_1], %[c12]
\n
"
" v_mul_f32 %[c13], %[scale_1], %[c13]
\n
"
" v_mul_f32 %[c14], %[scale_1], %[c14]
\n
"
" v_mul_f32 %[c15], %[scale_1], %[c15]
\n
"
_UK_PK_CVT_
(
"%[c0]"
,
"%[c1]"
,
"%[c0]"
)
_UK_PK_CVT_
(
"%[c2]"
,
"%[c3]"
,
"%[c1]"
)
_UK_PK_CVT_
(
"%[c4]"
,
"%[c5]"
,
"%[c2]"
)
_UK_PK_CVT_
(
"%[c6]"
,
"%[c7]"
,
"%[c3]"
)
_UK_PK_CVT_
(
"%[c8]"
,
"%[c9]"
,
"%[c4]"
)
_UK_PK_CVT_
(
"%[c10]"
,
"%[c11]"
,
"%[c5]"
)
_UK_PK_CVT_
(
"%[c12]"
,
"%[c13]"
,
"%[c6]"
)
_UK_PK_CVT_
(
"%[c14]"
,
"%[c15]"
,
"%[c7]"
)
" ;------------------------------
\n
"
" ds_write_b64 %[v_sfl_sst], [%[c0],%[c1]] offset:0 + %[shfl_base]
\n
"
" ds_write_b64 %[v_sfl_sst], [%[c2],%[c3]] offset:4352 + %[shfl_base]
\n
"
" ds_write_b64 %[v_sfl_sst], [%[c4],%[c5]] offset:2176 + %[shfl_base]
\n
"
" ds_write_b64 %[v_sfl_sst], [%[c6],%[c7]] offset:6528 + %[shfl_base]
\n
"
" s_waitcnt lgkmcnt(0)
\n
"
" s_barrier
\n
"
" ds_read_b32 %[c0], %[v_sfl_sld] offset:0 + %[shfl_base]
\n
"
" ds_read_b32 %[c1], %[v_sfl_sld] offset:32 + %[shfl_base]
\n
"
" ds_read_b32 %[c2], %[v_sfl_sld] offset:64 + %[shfl_base]
\n
"
" ds_read_b32 %[c3], %[v_sfl_sld] offset:96 + %[shfl_base]
\n
"
" ds_read_b32 %[c4], %[v_sfl_sld] offset:4352 + %[shfl_base]
\n
"
" ds_read_b32 %[c5], %[v_sfl_sld] offset:4384 + %[shfl_base]
\n
"
" ds_read_b32 %[c6], %[v_sfl_sld] offset:4416 + %[shfl_base]
\n
"
" ds_read_b32 %[c7], %[v_sfl_sld] offset:4448 + %[shfl_base]
\n
"
" s_waitcnt lgkmcnt(0)
\n
"
" s_mov_b64 exec, %[s_execflag_0]
\n
"
_UK_ATOMIC_ADD_
" %[v_os_o0], %[c0], s[8:9]
\n
"
" s_mov_b64 exec, %[s_execflag_1]
\n
"
_UK_ATOMIC_ADD_
" %[v_os_o1], %[c1], s[8:9]
\n
"
" s_mov_b64 exec, %[s_execflag_2]
\n
"
_UK_ATOMIC_ADD_
" %[v_os_o2], %[c2], s[8:9]
\n
"
" s_mov_b64 exec, %[s_execflag_3]
\n
"
_UK_ATOMIC_ADD_
" %[v_os_o3], %[c3], s[8:9]
\n
"
" s_mov_b64 exec, %[s_execflag_4]
\n
"
_UK_ATOMIC_ADD_
" %[v_os_o4], %[c4], s[8:9]
\n
"
" s_mov_b64 exec, %[s_execflag_5]
\n
"
_UK_ATOMIC_ADD_
" %[v_os_o5], %[c5], s[8:9]
\n
"
" s_mov_b64 exec, %[s_execflag_6]
\n
"
_UK_ATOMIC_ADD_
" %[v_os_o6], %[c6], s[8:9]
\n
"
" s_mov_b64 exec, %[s_execflag_7]
\n
"
_UK_ATOMIC_ADD_
" %[v_os_o7], %[c7], s[8:9]
\n
"
" s_mov_b64 exec, s[38:39]
\n
"
" s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1 ; k--
\n
"
" s_cmp_gt_i32 %[s_loop_cnt] 0
\n
"
" s_cbranch_scc0 L_end%=
\n
"
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond
\n
"
" s_cselect_b32 s86, %[s_tile_os_b], 0
\n
"
" s_add_u32 s12, s86, s12
\n
"
" s_addc_u32 s13, 0, s13
\n
"
" s_add_u32 s8, %[s_tile_os_o], s8
\n
"
" s_addc_u32 s9, 0, s9
\n
"
" s_waitcnt vmcnt(32)
\n
"
" s_barrier
\n
"
_UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[128:129], v[128:129], 0
\n
"
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[12:15], 0 offen
\n
"
_UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[130:131], v[130:131], [%[c16],%[c17],%[c18],%[c19]]
\n
"
_UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[132:133], v[132:133], [%[c16],%[c17],%[c18],%[c19]]
\n
"
_UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[134:135], v[134:135], [%[c16],%[c17],%[c18],%[c19]]
\n
"
_UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[136:137], v[136:137], [%[c16],%[c17],%[c18],%[c19]]
\n
"
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[12:15], 0 offen offset:1024
\n
"
_UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[138:139], v[138:139], [%[c16],%[c17],%[c18],%[c19]]
\n
"
_UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[140:141], v[140:141], [%[c16],%[c17],%[c18],%[c19]]
\n
"
_UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[142:143], v[142:143], [%[c16],%[c17],%[c18],%[c19]]
\n
"
_UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[128:129], v[192:193], 0
\n
"
" buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[12:15], 0 offen offset:2048
\n
"
_UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[130:131], v[194:195], [%[c20],%[c21],%[c22],%[c23]]
\n
"
_UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[132:133], v[196:197], [%[c20],%[c21],%[c22],%[c23]]
\n
"
_UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[134:135], v[198:199], [%[c20],%[c21],%[c22],%[c23]]
\n
"
_UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[136:137], v[200:201], [%[c20],%[c21],%[c22],%[c23]]
\n
"
" buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[12:15], 0 offen offset:3072
\n
"
_UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[138:139], v[202:203], [%[c20],%[c21],%[c22],%[c23]]
\n
"
_UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[140:141], v[204:205], [%[c20],%[c21],%[c22],%[c23]]
\n
"
_UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[142:143], v[206:207], [%[c20],%[c21],%[c22],%[c23]]
\n
"
_UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[144:145], v[128:129], 0
\n
"
" buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[12:15], 0 offen
\n
"
_UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[146:147], v[130:131], [%[c24],%[c25],%[c26],%[c27]]
\n
"
_UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[148:149], v[132:133], [%[c24],%[c25],%[c26],%[c27]]
\n
"
_UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[150:151], v[134:135], [%[c24],%[c25],%[c26],%[c27]]
\n
"
_UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[152:153], v[136:137], [%[c24],%[c25],%[c26],%[c27]]
\n
"
" buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[12:15], 0 offen offset:1024
\n
"
_UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[154:155], v[138:139], [%[c24],%[c25],%[c26],%[c27]]
\n
"
_UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[156:157], v[140:141], [%[c24],%[c25],%[c26],%[c27]]
\n
"
_UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[158:159], v[142:143], [%[c24],%[c25],%[c26],%[c27]]
\n
"
_UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[144:145], v[192:193], 0
\n
"
" buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[12:15], 0 offen offset:2048
\n
"
_UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[146:147], v[194:195], [%[c28],%[c29],%[c30],%[c31]]
\n
"
_UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[148:149], v[196:197], [%[c28],%[c29],%[c30],%[c31]]
\n
"
_UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[150:151], v[198:199], [%[c28],%[c29],%[c30],%[c31]]
\n
"
_UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[152:153], v[200:201], [%[c28],%[c29],%[c30],%[c31]]
\n
"
" buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[12:15], 0 offen offset:3072
\n
"
_UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[154:155], v[202:203], [%[c28],%[c29],%[c30],%[c31]]
\n
"
_UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[156:157], v[204:205], [%[c28],%[c29],%[c30],%[c31]]
\n
"
_UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[158:159], v[206:207], [%[c28],%[c29],%[c30],%[c31]]
\n
"
" s_waitcnt vmcnt(32)
\n
"
_UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[160:161], v[144:145], [%[c16],%[c17],%[c18],%[c19]]
\n
"
" buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[12:15], 0 offen
\n
"
_UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[162:163], v[146:147], [%[c16],%[c17],%[c18],%[c19]]
\n
"
_UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[164:165], v[148:149], [%[c16],%[c17],%[c18],%[c19]]
\n
"
_UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[166:167], v[150:151], [%[c16],%[c17],%[c18],%[c19]]
\n
"
_UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[168:169], v[152:153], [%[c16],%[c17],%[c18],%[c19]]
\n
"
" buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[12:15], 0 offen offset:1024
\n
"
_UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[170:171], v[154:155], [%[c16],%[c17],%[c18],%[c19]]
\n
"
_UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[172:173], v[156:157], [%[c16],%[c17],%[c18],%[c19]]
\n
"
_UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[174:175], v[158:159], [%[c16],%[c17],%[c18],%[c19]]
\n
"
_UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[160:161], v[208:209], [%[c20],%[c21],%[c22],%[c23]]
\n
"
" buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[12:15], 0 offen offset:2048
\n
"
_UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[162:163], v[210:211], [%[c20],%[c21],%[c22],%[c23]]
\n
"
_UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[164:165], v[212:213], [%[c20],%[c21],%[c22],%[c23]]
\n
"
_UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[166:167], v[214:215], [%[c20],%[c21],%[c22],%[c23]]
\n
"
_UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[168:169], v[216:217], [%[c20],%[c21],%[c22],%[c23]]
\n
"
" buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[12:15], 0 offen offset:3072
\n
"
_UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[170:171], v[218:219], [%[c20],%[c21],%[c22],%[c23]]
\n
"
_UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[172:173], v[220:221], [%[c20],%[c21],%[c22],%[c23]]
\n
"
_UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[174:175], v[222:223], [%[c20],%[c21],%[c22],%[c23]]
\n
"
_UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[176:177], v[144:145], [%[c24],%[c25],%[c26],%[c27]]
\n
"
" buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[12:15], 0 offen
\n
"
_UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[178:179], v[146:147], [%[c24],%[c25],%[c26],%[c27]]
\n
"
_UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[180:181], v[148:149], [%[c24],%[c25],%[c26],%[c27]]
\n
"
_UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[182:183], v[150:151], [%[c24],%[c25],%[c26],%[c27]]
\n
"
_UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[184:185], v[152:153], [%[c24],%[c25],%[c26],%[c27]]
\n
"
" buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[12:15], 0 offen offset:1024
\n
"
_UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[186:187], v[154:155], [%[c24],%[c25],%[c26],%[c27]]
\n
"
_UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[188:189], v[156:157], [%[c24],%[c25],%[c26],%[c27]]
\n
"
_UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[190:191], v[158:159], [%[c24],%[c25],%[c26],%[c27]]
\n
"
_UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[176:177], v[208:209], [%[c28],%[c29],%[c30],%[c31]]
\n
"
" buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[12:15], 0 offen offset:2048
\n
"
_UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[178:179], v[210:211], [%[c28],%[c29],%[c30],%[c31]]
\n
"
_UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[180:181], v[212:213], [%[c28],%[c29],%[c30],%[c31]]
\n
"
_UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[182:183], v[214:215], [%[c28],%[c29],%[c30],%[c31]]
\n
"
_UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[184:185], v[216:217], [%[c28],%[c29],%[c30],%[c31]]
\n
"
" buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[12:15], 0 offen offset:3072
\n
"
_UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[186:187], v[218:219], [%[c28],%[c29],%[c30],%[c31]]
\n
"
_UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[188:189], v[220:221], [%[c28],%[c29],%[c30],%[c31]]
\n
"
_UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[190:191], v[222:223], [%[c28],%[c29],%[c30],%[c31]]
\n
"
" s_waitcnt vmcnt(32)
\n
"
_UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[192:193], v[160:161], [%[c16],%[c17],%[c18],%[c19]]
\n
"
" buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[12:15], 0 offen
\n
"
_UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[194:195], v[162:163], [%[c16],%[c17],%[c18],%[c19]]
\n
"
_UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[196:197], v[164:165], [%[c16],%[c17],%[c18],%[c19]]
\n
"
_UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[198:199], v[166:167], [%[c16],%[c17],%[c18],%[c19]]
\n
"
_UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[200:201], v[168:169], [%[c16],%[c17],%[c18],%[c19]]
\n
"
" buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[12:15], 0 offen offset:1024
\n
"
_UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[202:203], v[170:171], [%[c16],%[c17],%[c18],%[c19]]
\n
"
_UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[204:205], v[172:173], [%[c16],%[c17],%[c18],%[c19]]
\n
"
_UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[206:207], v[174:175], [%[c16],%[c17],%[c18],%[c19]]
\n
"
_UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[192:193], v[224:225], [%[c20],%[c21],%[c22],%[c23]]
\n
"
" buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[12:15], 0 offen offset:2048
\n
"
_UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[194:195], v[226:227], [%[c20],%[c21],%[c22],%[c23]]
\n
"
_UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[196:197], v[228:229], [%[c20],%[c21],%[c22],%[c23]]
\n
"
_UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[198:199], v[230:231], [%[c20],%[c21],%[c22],%[c23]]
\n
"
_UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[200:201], v[232:233], [%[c20],%[c21],%[c22],%[c23]]
\n
"
" buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[12:15], 0 offen offset:3072
\n
"
_UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[202:203], v[234:235], [%[c20],%[c21],%[c22],%[c23]]
\n
"
_UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[204:205], v[236:237], [%[c20],%[c21],%[c22],%[c23]]
\n
"
_UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[206:207], v[238:239], [%[c20],%[c21],%[c22],%[c23]]
\n
"
_UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[208:209], v[160:161], [%[c24],%[c25],%[c26],%[c27]]
\n
"
" buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[12:15], 0 offen
\n
"
_UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[210:211], v[162:163], [%[c24],%[c25],%[c26],%[c27]]
\n
"
_UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[212:213], v[164:165], [%[c24],%[c25],%[c26],%[c27]]
\n
"
_UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[214:215], v[166:167], [%[c24],%[c25],%[c26],%[c27]]
\n
"
_UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[216:217], v[168:169], [%[c24],%[c25],%[c26],%[c27]]
\n
"
" buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[12:15], 0 offen offset:1024
\n
"
_UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[218:219], v[170:171], [%[c24],%[c25],%[c26],%[c27]]
\n
"
_UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[220:221], v[172:173], [%[c24],%[c25],%[c26],%[c27]]
\n
"
_UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[222:223], v[174:175], [%[c24],%[c25],%[c26],%[c27]]
\n
"
_UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[208:209], v[224:225], [%[c28],%[c29],%[c30],%[c31]]
\n
"
" buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[12:15], 0 offen offset:2048
\n
"
_UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[210:211], v[226:227], [%[c28],%[c29],%[c30],%[c31]]
\n
"
_UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[212:213], v[228:229], [%[c28],%[c29],%[c30],%[c31]]
\n
"
_UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[214:215], v[230:231], [%[c28],%[c29],%[c30],%[c31]]
\n
"
_UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[216:217], v[232:233], [%[c28],%[c29],%[c30],%[c31]]
\n
"
" buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[12:15], 0 offen offset:3072
\n
"
_UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[218:219], v[234:235], [%[c28],%[c29],%[c30],%[c31]]
\n
"
_UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[220:221], v[236:237], [%[c28],%[c29],%[c30],%[c31]]
\n
"
_UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[222:223], v[238:239], [%[c28],%[c29],%[c30],%[c31]]
\n
"
" s_waitcnt vmcnt(32)
\n
"
_UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[224:225], v[176:177], [%[c16],%[c17],%[c18],%[c19]]
\n
"
" buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[12:15], 0 offen
\n
"
_UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[226:227], v[178:179], [%[c16],%[c17],%[c18],%[c19]]
\n
"
_UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[228:229], v[180:181], [%[c16],%[c17],%[c18],%[c19]]
\n
"
_UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[230:231], v[182:183], [%[c16],%[c17],%[c18],%[c19]]
\n
"
_UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[232:233], v[184:185], [%[c16],%[c17],%[c18],%[c19]]
\n
"
" buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[12:15], 0 offen offset:1024
\n
"
_UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[234:235], v[186:187], [%[c16],%[c17],%[c18],%[c19]]
\n
"
_UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[236:237], v[188:189], [%[c16],%[c17],%[c18],%[c19]]
\n
"
_UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[238:239], v[190:191], [%[c16],%[c17],%[c18],%[c19]]
\n
"
_UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[224:225], v[240:241], [%[c20],%[c21],%[c22],%[c23]]
\n
"
" buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[12:15], 0 offen offset:2048
\n
"
_UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[226:227], v[242:243], [%[c20],%[c21],%[c22],%[c23]]
\n
"
_UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[228:229], v[244:245], [%[c20],%[c21],%[c22],%[c23]]
\n
"
_UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[230:231], v[246:247], [%[c20],%[c21],%[c22],%[c23]]
\n
"
_UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[232:233], v[248:249], [%[c20],%[c21],%[c22],%[c23]]
\n
"
" buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[12:15], 0 offen offset:3072
\n
"
_UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[234:235], v[250:251], [%[c20],%[c21],%[c22],%[c23]]
\n
"
_UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[236:237], v[252:253], [%[c20],%[c21],%[c22],%[c23]]
\n
"
_UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[238:239], v[254:255], [%[c20],%[c21],%[c22],%[c23]]
\n
"
_UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[240:241], v[176:177], [%[c24],%[c25],%[c26],%[c27]]
\n
"
" buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[12:15], 0 offen
\n
"
_UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[242:243], v[178:179], [%[c24],%[c25],%[c26],%[c27]]
\n
"
_UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[244:245], v[180:181], [%[c24],%[c25],%[c26],%[c27]]
\n
"
_UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[246:247], v[182:183], [%[c24],%[c25],%[c26],%[c27]]
\n
"
_UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[248:249], v[184:185], [%[c24],%[c25],%[c26],%[c27]]
\n
"
" buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[12:15], 0 offen offset:1024
\n
"
_UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[250:251], v[186:187], [%[c24],%[c25],%[c26],%[c27]]
\n
"
_UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[252:253], v[188:189], [%[c24],%[c25],%[c26],%[c27]]
\n
"
_UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[254:255], v[190:191], [%[c24],%[c25],%[c26],%[c27]]
\n
"
_UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[240:241], v[240:241], [%[c28],%[c29],%[c30],%[c31]]
\n
"
" buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[12:15], 0 offen offset:2048
\n
"
_UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[242:243], v[242:243], [%[c28],%[c29],%[c30],%[c31]]
\n
"
_UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[244:245], v[244:245], [%[c28],%[c29],%[c30],%[c31]]
\n
"
_UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[246:247], v[246:247], [%[c28],%[c29],%[c30],%[c31]]
\n
"
_UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[248:249], v[248:249], [%[c28],%[c29],%[c30],%[c31]]
\n
"
" buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[12:15], 0 offen offset:3072
\n
"
_UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[250:251], v[250:251], [%[c28],%[c29],%[c30],%[c31]]
\n
"
_UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[252:253], v[252:253], [%[c28],%[c29],%[c30],%[c31]]
\n
"
_UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[254:255], v[254:255], [%[c28],%[c29],%[c30],%[c31]]
\n
"
" v_mul_f32 %[c16], %[scale_0], %[c16]
\n
"
" v_mul_f32 %[c17], %[scale_0], %[c17]
\n
"
" v_mul_f32 %[c18], %[scale_0], %[c18]
\n
"
" v_mul_f32 %[c19], %[scale_0], %[c19]
\n
"
" v_mul_f32 %[c20], %[scale_1], %[c20]
\n
"
" v_mul_f32 %[c21], %[scale_1], %[c21]
\n
"
" v_mul_f32 %[c22], %[scale_1], %[c22]
\n
"
" v_mul_f32 %[c23], %[scale_1], %[c23]
\n
"
" v_mul_f32 %[c24], %[scale_0], %[c24]
\n
"
" v_mul_f32 %[c25], %[scale_0], %[c25]
\n
"
" v_mul_f32 %[c26], %[scale_0], %[c26]
\n
"
" v_mul_f32 %[c27], %[scale_0], %[c27]
\n
"
" v_mul_f32 %[c28], %[scale_1], %[c28]
\n
"
" v_mul_f32 %[c29], %[scale_1], %[c29]
\n
"
" v_mul_f32 %[c30], %[scale_1], %[c30]
\n
"
" v_mul_f32 %[c31], %[scale_1], %[c31]
\n
"
_UK_PK_CVT_
(
"%[c16]"
,
"%[c17]"
,
"%[c16]"
)
_UK_PK_CVT_
(
"%[c18]"
,
"%[c19]"
,
"%[c17]"
)
_UK_PK_CVT_
(
"%[c20]"
,
"%[c21]"
,
"%[c18]"
)
_UK_PK_CVT_
(
"%[c22]"
,
"%[c23]"
,
"%[c19]"
)
_UK_PK_CVT_
(
"%[c24]"
,
"%[c25]"
,
"%[c20]"
)
_UK_PK_CVT_
(
"%[c26]"
,
"%[c27]"
,
"%[c21]"
)
_UK_PK_CVT_
(
"%[c28]"
,
"%[c29]"
,
"%[c22]"
)
_UK_PK_CVT_
(
"%[c30]"
,
"%[c31]"
,
"%[c23]"
)
" ;------------------------------
\n
"
" ds_write_b64 %[v_sfl_sst], [%[c16],%[c17]] offset:0 + %[shfl_base]
\n
"
" ds_write_b64 %[v_sfl_sst], [%[c18],%[c19]] offset:4352 + %[shfl_base]
\n
"
" ds_write_b64 %[v_sfl_sst], [%[c20],%[c21]] offset:2176 + %[shfl_base]
\n
"
" ds_write_b64 %[v_sfl_sst], [%[c22],%[c23]] offset:6528 + %[shfl_base]
\n
"
" s_waitcnt lgkmcnt(0)
\n
"
" s_barrier
\n
"
" ds_read_b32 %[c16], %[v_sfl_sld] offset:0 + %[shfl_base]
\n
"
" ds_read_b32 %[c17], %[v_sfl_sld] offset:32 + %[shfl_base]
\n
"
" ds_read_b32 %[c18], %[v_sfl_sld] offset:64 + %[shfl_base]
\n
"
" ds_read_b32 %[c19], %[v_sfl_sld] offset:96 + %[shfl_base]
\n
"
" ds_read_b32 %[c20], %[v_sfl_sld] offset:4352 + %[shfl_base]
\n
"
" ds_read_b32 %[c21], %[v_sfl_sld] offset:4384 + %[shfl_base]
\n
"
" ds_read_b32 %[c22], %[v_sfl_sld] offset:4416 + %[shfl_base]
\n
"
" ds_read_b32 %[c23], %[v_sfl_sld] offset:4448 + %[shfl_base]
\n
"
" s_waitcnt lgkmcnt(0)
\n
"
" s_mov_b64 exec, %[s_execflag_0]
\n
"
_UK_ATOMIC_ADD_
" %[v_os_o0], %[c16], s[8:9]
\n
"
" s_mov_b64 exec, %[s_execflag_1]
\n
"
_UK_ATOMIC_ADD_
" %[v_os_o1], %[c17], s[8:9]
\n
"
" s_mov_b64 exec, %[s_execflag_2]
\n
"
_UK_ATOMIC_ADD_
" %[v_os_o2], %[c18], s[8:9]
\n
"
" s_mov_b64 exec, %[s_execflag_3]
\n
"
_UK_ATOMIC_ADD_
" %[v_os_o3], %[c19], s[8:9]
\n
"
" s_mov_b64 exec, %[s_execflag_4]
\n
"
_UK_ATOMIC_ADD_
" %[v_os_o4], %[c20], s[8:9]
\n
"
" s_mov_b64 exec, %[s_execflag_5]
\n
"
_UK_ATOMIC_ADD_
" %[v_os_o5], %[c21], s[8:9]
\n
"
" s_mov_b64 exec, %[s_execflag_6]
\n
"
_UK_ATOMIC_ADD_
" %[v_os_o6], %[c22], s[8:9]
\n
"
" s_mov_b64 exec, %[s_execflag_7]
\n
"
_UK_ATOMIC_ADD_
" %[v_os_o7], %[c23], s[8:9]
\n
"
" s_mov_b64 exec, s[38:39]
\n
"
" s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1 ; k--
\n
"
" s_cmp_gt_i32 %[s_loop_cnt] 0
\n
"
" s_cbranch_scc0 L_end%=
\n
"
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond
\n
"
" s_cselect_b32 s86, %[s_tile_os_b], 0
\n
"
" s_add_u32 s12, s86, s12
\n
"
" s_addc_u32 s13, 0, s13
\n
"
" s_add_u32 s8, %[s_tile_os_o], s8
\n
"
" s_addc_u32 s9, 0, s9
\n
"
" s_branch L_start%=
\n
"
"L_end%=:
\n
"
#undef _UK_MFMA_
#undef _UK_PK_CVT_
#undef _UK_ATOMIC_ADD_
include/ck_tile/ops/flatmm/block/uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc
0 → 100644
View file @
4525c5d7
#ifndef CK_TILE_FLATMM_UK_MFMA
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
#endif
#if CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_BF16
#define _UK_MFMA_ "v_mfma_f32_16x16x16_bf16"
#elif CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_FP16
#define _UK_MFMA_ "v_mfma_f32_16x16x16_f16"
#endif
"s_mov_b32 s16, %[s_res_a0]
\n
"
"s_mov_b32 s17, %[s_res_a1]
\n
"
"s_mov_b32 s18, %[s_res_a2]
\n
"
"s_mov_b32 s19, %[s_res_a3]
\n
"
"s_mov_b32 s20, %[s_res_b0]
\n
"
"s_mov_b32 s21, %[s_res_b1]
\n
"
"s_mov_b32 s22, %[s_res_b2]
\n
"
"s_mov_b32 s23, %[s_res_b3]
\n
"
// "s_nop 4\n"
"; -- prefetch A0
\n
"
"s_add_u32 m0, 0, %[s_m0_init]
\n
"
"buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds
\n
"
"s_add_u32 m0, %[s_size_per_issue], m0
\n
"
"buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds
\n
"
"s_add_u32 m0, %[s_size_per_issue], m0
\n
"
"buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds
\n
"
"s_add_u32 m0, %[s_size_per_issue], m0
\n
"
"buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds
\n
"
"s_add_u32 m0, %[s_size_per_issue], m0
\n
"
"buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds
\n
"
"s_add_u32 m0, %[s_size_per_issue], m0
\n
"
"buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds
\n
"
"s_add_u32 m0, %[s_size_per_issue], m0
\n
"
"buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds
\n
"
"s_add_u32 m0, %[s_size_per_issue], m0
\n
"
"buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds
\n
"
"s_add_u32 m0, %[smem_sz], %[s_m0_init]
\n
"
"s_cmp_gt_i32 %[s_loop_cnt] 1 ; move a with cond
\n
"
"s_cselect_b32 s86, %[s_tile_os_a], 0 ; move a with cond
\n
"
"s_add_u32 s16, s86, s16 ; move a with cond
\n
"
"s_addc_u32 s17, 0, s17 ; move a with cond
\n
"
"; -- prefetch A1
\n
"
"buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds
\n
"
"s_add_u32 m0, %[s_size_per_issue], m0
\n
"
"buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds
\n
"
"s_add_u32 m0, %[s_size_per_issue], m0
\n
"
"buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds
\n
"
"s_add_u32 m0, %[s_size_per_issue], m0
\n
"
"buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds
\n
"
"s_add_u32 m0, %[s_size_per_issue], m0
\n
"
"buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds
\n
"
"s_add_u32 m0, %[s_size_per_issue], m0
\n
"
"buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds
\n
"
"s_add_u32 m0, %[s_size_per_issue], m0
\n
"
"buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds
\n
"
"s_add_u32 m0, %[s_size_per_issue], m0
\n
"
"buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds
\n
"
"s_add_u32 m0, 0, %[s_m0_init]
\n
"
"s_cmp_gt_i32 %[s_loop_cnt] 2 ; move a with cond
\n
"
"s_cselect_b32 s86, %[s_tile_os_a], 0 ; move a with cond
\n
"
"s_add_u32 s16, s86, s16 ; move a with cond
\n
"
"s_addc_u32 s17, 0, s17 ; move a with cond
\n
"
"; -- prefetch B0
\n
"
"buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[20:23], 0 offen
\n
"
"buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[20:23], 0 offen offset:1024
\n
"
"buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[20:23], 0 offen offset:2048
\n
"
"buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[20:23], 0 offen offset:3072
\n
"
"buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[20:23], 0 offen
\n
"
"buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[20:23], 0 offen offset:1024
\n
"
"buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[20:23], 0 offen offset:2048
\n
"
"buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[20:23], 0 offen offset:3072
\n
"
"buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[20:23], 0 offen
\n
"
"buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[20:23], 0 offen offset:1024
\n
"
"buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[20:23], 0 offen offset:2048
\n
"
"buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[20:23], 0 offen offset:3072
\n
"
"buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[20:23], 0 offen
\n
"
"buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[20:23], 0 offen offset:1024
\n
"
"buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[20:23], 0 offen offset:2048
\n
"
"buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[20:23], 0 offen offset:3072
\n
"
"buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[20:23], 0 offen
\n
"
"buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[20:23], 0 offen offset:1024
\n
"
"buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[20:23], 0 offen offset:2048
\n
"
"buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[20:23], 0 offen offset:3072
\n
"
"buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[20:23], 0 offen
\n
"
"buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[20:23], 0 offen offset:1024
\n
"
"buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[20:23], 0 offen offset:2048
\n
"
"buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[20:23], 0 offen offset:3072
\n
"
"buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[20:23], 0 offen
\n
"
"buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[20:23], 0 offen offset:1024
\n
"
"buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[20:23], 0 offen offset:2048
\n
"
"buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[20:23], 0 offen offset:3072
\n
"
"buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[20:23], 0 offen
\n
"
"buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[20:23], 0 offen offset:1024
\n
"
"buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[20:23], 0 offen offset:2048
\n
"
"buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[20:23], 0 offen offset:3072
\n
"
"s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond
\n
"
"s_cselect_b32 s86, %[s_tile_os_b], 0 ; move b with cond
\n
"
"s_add_u32 s20, s86, s20 ; move b with cond
\n
"
"s_addc_u32 s21, 0, s21 ; move b with cond
\n
"
"s_waitcnt vmcnt(40)
\n
"
"s_barrier
\n
"
"ds_read_b128 v[64:67], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_0]
\n
"
// 1024: N stride, 64 K stride
"ds_read_b128 v[68:71], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_1]
\n
"
"ds_read_b128 v[72:75], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_2]
\n
"
"ds_read_b128 v[76:79], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_3]
\n
"
"ds_read_b128 v[80:83], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_4]
\n
"
"ds_read_b128 v[84:87], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_5]
\n
"
"ds_read_b128 v[88:91], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_6]
\n
"
"ds_read_b128 v[92:95], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_7]
\n
"
"L_start%=:
\n
"
" s_waitcnt vmcnt(24) & lgkmcnt(0)
\n
"
" s_barrier
\n
"
_UK_MFMA_
" %[v_acc_0], acc[0:1], v[64:65], %[v_acc_0]
\n
"
_UK_MFMA_
" %[v_acc_0], acc[2:3], v[66:67], %[v_acc_0]
\n
"
" buffer_load_dwordx4 acc[128:131], %[v_os_b0], s[20:23], 0 offen
\n
"
_UK_MFMA_
" %[v_acc_0], acc[4:5], v[68:69], %[v_acc_0]
\n
"
_UK_MFMA_
" %[v_acc_0], acc[6:7], v[70:71], %[v_acc_0]
\n
"
" buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds
\n
"
" s_add_u32 m0, %[s_size_per_issue], m0
\n
"
_UK_MFMA_
" %[v_acc_0], acc[8:9], v[72:73], %[v_acc_0]
\n
"
_UK_MFMA_
" %[v_acc_0], acc[10:11], v[74:75], %[v_acc_0]
\n
"
" buffer_load_dwordx4 acc[132:135], %[v_os_b0], s[20:23], 0 offen offset:1024
\n
"
_UK_MFMA_
" %[v_acc_0], acc[12:13], v[76:77], %[v_acc_0]
\n
"
_UK_MFMA_
" %[v_acc_0], acc[14:15], v[78:79], %[v_acc_0]
\n
"
" buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds
\n
"
" s_add_u32 m0, %[s_size_per_issue], m0
\n
"
_UK_MFMA_
" %[v_acc_1], acc[0:1], v[80:81], %[v_acc_1]
\n
"
_UK_MFMA_
" %[v_acc_1], acc[2:3], v[82:83], %[v_acc_1]
\n
"
" buffer_load_dwordx4 acc[136:139], %[v_os_b0], s[20:23], 0 offen offset:2048
\n
"
_UK_MFMA_
" %[v_acc_1], acc[4:5], v[84:85], %[v_acc_1]
\n
"
_UK_MFMA_
" %[v_acc_1], acc[6:7], v[86:87], %[v_acc_1]
\n
"
" buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds
\n
"
" s_add_u32 m0, %[s_size_per_issue], m0
\n
"
_UK_MFMA_
" %[v_acc_1], acc[8:9], v[88:89], %[v_acc_1]
\n
"
_UK_MFMA_
" %[v_acc_1], acc[10:11], v[90:91], %[v_acc_1]
\n
"
" buffer_load_dwordx4 acc[140:143], %[v_os_b0], s[20:23], 0 offen offset:3072
\n
"
_UK_MFMA_
" %[v_acc_1], acc[12:13], v[92:93], %[v_acc_1]
\n
"
_UK_MFMA_
" %[v_acc_1], acc[14:15], v[94:95], %[v_acc_1]
\n
"
" buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds
\n
"
" s_add_u32 m0, %[s_size_per_issue], m0
\n
"
_UK_MFMA_
" %[v_acc_2], acc[16:17], v[64:65], %[v_acc_2]
\n
"
_UK_MFMA_
" %[v_acc_2], acc[18:19], v[66:67], %[v_acc_2]
\n
"
" buffer_load_dwordx4 acc[144:147], %[v_os_b1], s[20:23], 0 offen
\n
"
_UK_MFMA_
" %[v_acc_2], acc[20:21], v[68:69], %[v_acc_2]
\n
"
_UK_MFMA_
" %[v_acc_2], acc[22:23], v[70:71], %[v_acc_2]
\n
"
" buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds
\n
"
" s_add_u32 m0, %[s_size_per_issue], m0
\n
"
_UK_MFMA_
" %[v_acc_2], acc[24:25], v[72:73], %[v_acc_2]
\n
"
_UK_MFMA_
" %[v_acc_2], acc[26:27], v[74:75], %[v_acc_2]
\n
"
" buffer_load_dwordx4 acc[148:151], %[v_os_b1], s[20:23], 0 offen offset:1024
\n
"
_UK_MFMA_
" %[v_acc_2], acc[28:29], v[76:77], %[v_acc_2]
\n
"
_UK_MFMA_
" %[v_acc_2], acc[30:31], v[78:79], %[v_acc_2]
\n
"
" buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds
\n
"
" s_add_u32 m0, %[s_size_per_issue], m0
\n
"
_UK_MFMA_
" %[v_acc_3], acc[16:17], v[80:81], %[v_acc_3]
\n
"
_UK_MFMA_
" %[v_acc_3], acc[18:19], v[82:83], %[v_acc_3]
\n
"
" buffer_load_dwordx4 acc[152:155], %[v_os_b1], s[20:23], 0 offen offset:2048
\n
"
_UK_MFMA_
" %[v_acc_3], acc[20:21], v[84:85], %[v_acc_3]
\n
"
_UK_MFMA_
" %[v_acc_3], acc[22:23], v[86:87], %[v_acc_3]
\n
"
" buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds
\n
"
" s_add_u32 m0, %[s_size_per_issue], m0
\n
"
_UK_MFMA_
" %[v_acc_3], acc[24:25], v[88:89], %[v_acc_3]
\n
"
_UK_MFMA_
" %[v_acc_3], acc[26:27], v[90:91], %[v_acc_3]
\n
"
" buffer_load_dwordx4 acc[156:159], %[v_os_b1], s[20:23], 0 offen offset:3072
\n
"
_UK_MFMA_
" %[v_acc_3], acc[28:29], v[92:93], %[v_acc_3]
\n
"
_UK_MFMA_
" %[v_acc_3], acc[30:31], v[94:95], %[v_acc_3]
\n
"
" buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds
\n
"
" s_add_u32 m0, %[smem_sz], %[s_m0_init]
\n
"
" s_waitcnt vmcnt(32)
\n
"
_UK_MFMA_
" %[v_acc_4], acc[32:33], v[64:65], %[v_acc_4]
\n
"
_UK_MFMA_
" %[v_acc_4], acc[34:35], v[66:67], %[v_acc_4]
\n
"
" buffer_load_dwordx4 acc[160:163], %[v_os_b2], s[20:23], 0 offen
\n
"
_UK_MFMA_
" %[v_acc_4], acc[36:37], v[68:69], %[v_acc_4]
\n
"
_UK_MFMA_
" %[v_acc_4], acc[38:39], v[70:71], %[v_acc_4]
\n
"
" ds_read_b128 v[96:99], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_0]
\n
"
_UK_MFMA_
" %[v_acc_4], acc[40:41], v[72:73], %[v_acc_4]
\n
"
_UK_MFMA_
" %[v_acc_4], acc[42:43], v[74:75], %[v_acc_4]
\n
"
" buffer_load_dwordx4 acc[164:167], %[v_os_b2], s[20:23], 0 offen offset:1024
\n
"
_UK_MFMA_
" %[v_acc_4], acc[44:45], v[76:77], %[v_acc_4]
\n
"
_UK_MFMA_
" %[v_acc_4], acc[46:47], v[78:79], %[v_acc_4]
\n
"
" ds_read_b128 v[100:103], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_1]
\n
"
_UK_MFMA_
" %[v_acc_5], acc[32:33], v[80:81], %[v_acc_5]
\n
"
_UK_MFMA_
" %[v_acc_5], acc[34:35], v[82:83], %[v_acc_5]
\n
"
" buffer_load_dwordx4 acc[168:171], %[v_os_b2], s[20:23], 0 offen offset:2048
\n
"
_UK_MFMA_
" %[v_acc_5], acc[36:37], v[84:85], %[v_acc_5]
\n
"
_UK_MFMA_
" %[v_acc_5], acc[38:39], v[86:87], %[v_acc_5]
\n
"
" ds_read_b128 v[104:107], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_2]
\n
"
_UK_MFMA_
" %[v_acc_5], acc[40:41], v[88:89], %[v_acc_5]
\n
"
_UK_MFMA_
" %[v_acc_5], acc[42:43], v[90:91], %[v_acc_5]
\n
"
" buffer_load_dwordx4 acc[172:175], %[v_os_b2], s[20:23], 0 offen offset:3072
\n
"
_UK_MFMA_
" %[v_acc_5], acc[44:45], v[92:93], %[v_acc_5]
\n
"
_UK_MFMA_
" %[v_acc_5], acc[46:47], v[94:95], %[v_acc_5]
\n
"
" ds_read_b128 v[108:111], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_3]
\n
"
_UK_MFMA_
" %[v_acc_6], acc[48:49], v[64:65], %[v_acc_6]
\n
"
_UK_MFMA_
" %[v_acc_6], acc[50:51], v[66:67], %[v_acc_6]
\n
"
" buffer_load_dwordx4 acc[176:179], %[v_os_b3], s[20:23], 0 offen
\n
"
_UK_MFMA_
" %[v_acc_6], acc[52:53], v[68:69], %[v_acc_6]
\n
"
_UK_MFMA_
" %[v_acc_6], acc[54:55], v[70:71], %[v_acc_6]
\n
"
" ds_read_b128 v[112:115], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_4]
\n
"
_UK_MFMA_
" %[v_acc_6], acc[56:57], v[72:73], %[v_acc_6]
\n
"
_UK_MFMA_
" %[v_acc_6], acc[58:59], v[74:75], %[v_acc_6]
\n
"
" buffer_load_dwordx4 acc[180:183], %[v_os_b3], s[20:23], 0 offen offset:1024
\n
"
_UK_MFMA_
" %[v_acc_6], acc[60:61], v[76:77], %[v_acc_6]
\n
"
_UK_MFMA_
" %[v_acc_6], acc[62:63], v[78:79], %[v_acc_6]
\n
"
" ds_read_b128 v[116:119], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_5]
\n
"
_UK_MFMA_
" %[v_acc_7], acc[48:49], v[80:81], %[v_acc_7]
\n
"
_UK_MFMA_
" %[v_acc_7], acc[50:51], v[82:83], %[v_acc_7]
\n
"
" buffer_load_dwordx4 acc[184:187], %[v_os_b3], s[20:23], 0 offen offset:2048
\n
"
_UK_MFMA_
" %[v_acc_7], acc[52:53], v[84:85], %[v_acc_7]
\n
"
_UK_MFMA_
" %[v_acc_7], acc[54:55], v[86:87], %[v_acc_7]
\n
"
" ds_read_b128 v[120:123], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_6]
\n
"
_UK_MFMA_
" %[v_acc_7], acc[56:57], v[88:89], %[v_acc_7]
\n
"
_UK_MFMA_
" %[v_acc_7], acc[58:59], v[90:91], %[v_acc_7]
\n
"
" buffer_load_dwordx4 acc[188:191], %[v_os_b3], s[20:23], 0 offen offset:3072
\n
"
_UK_MFMA_
" %[v_acc_7], acc[60:61], v[92:93], %[v_acc_7]
\n
"
_UK_MFMA_
" %[v_acc_7], acc[62:63], v[94:95], %[v_acc_7]
\n
"
" ds_read_b128 v[124:127], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_7]
\n
"
" s_waitcnt vmcnt(32)
\n
"
_UK_MFMA_
" %[v_acc_8], acc[64:65], v[64:65], %[v_acc_8]
\n
"
_UK_MFMA_
" %[v_acc_8], acc[66:67], v[66:67], %[v_acc_8]
\n
"
" buffer_load_dwordx4 acc[192:195], %[v_os_b4], s[20:23], 0 offen
\n
"
_UK_MFMA_
" %[v_acc_8], acc[68:69], v[68:69], %[v_acc_8]
\n
"
_UK_MFMA_
" %[v_acc_8], acc[70:71], v[70:71], %[v_acc_8]
\n
"
_UK_MFMA_
" %[v_acc_8], acc[72:73], v[72:73], %[v_acc_8]
\n
"
_UK_MFMA_
" %[v_acc_8], acc[74:75], v[74:75], %[v_acc_8]
\n
"
" buffer_load_dwordx4 acc[196:199], %[v_os_b4], s[20:23], 0 offen offset:1024
\n
"
_UK_MFMA_
" %[v_acc_8], acc[76:77], v[76:77], %[v_acc_8]
\n
"
_UK_MFMA_
" %[v_acc_8], acc[78:79], v[78:79], %[v_acc_8]
\n
"
_UK_MFMA_
" %[v_acc_9], acc[64:65], v[80:81], %[v_acc_9]
\n
"
_UK_MFMA_
" %[v_acc_9], acc[66:67], v[82:83], %[v_acc_9]
\n
"
" buffer_load_dwordx4 acc[200:203], %[v_os_b4], s[20:23], 0 offen offset:2048
\n
"
_UK_MFMA_
" %[v_acc_9], acc[68:69], v[84:85], %[v_acc_9]
\n
"
_UK_MFMA_
" %[v_acc_9], acc[70:71], v[86:87], %[v_acc_9]
\n
"
_UK_MFMA_
" %[v_acc_9], acc[72:73], v[88:89], %[v_acc_9]
\n
"
_UK_MFMA_
" %[v_acc_9], acc[74:75], v[90:91], %[v_acc_9]
\n
"
" buffer_load_dwordx4 acc[204:207], %[v_os_b4], s[20:23], 0 offen offset:3072
\n
"
_UK_MFMA_
" %[v_acc_9], acc[76:77], v[92:93], %[v_acc_9]
\n
"
_UK_MFMA_
" %[v_acc_9], acc[78:79], v[94:95], %[v_acc_9]
\n
"
_UK_MFMA_
" %[v_acc_10], acc[80:81], v[64:65], %[v_acc_10]
\n
"
_UK_MFMA_
" %[v_acc_10], acc[82:83], v[66:67], %[v_acc_10]
\n
"
" buffer_load_dwordx4 acc[208:211], %[v_os_b5], s[20:23], 0 offen
\n
"
_UK_MFMA_
" %[v_acc_10], acc[84:85], v[68:69], %[v_acc_10]
\n
"
_UK_MFMA_
" %[v_acc_10], acc[86:87], v[70:71], %[v_acc_10]
\n
"
_UK_MFMA_
" %[v_acc_10], acc[88:89], v[72:73], %[v_acc_10]
\n
"
_UK_MFMA_
" %[v_acc_10], acc[90:91], v[74:75], %[v_acc_10]
\n
"
" buffer_load_dwordx4 acc[212:215], %[v_os_b5], s[20:23], 0 offen offset:1024
\n
"
_UK_MFMA_
" %[v_acc_10], acc[92:93], v[76:77], %[v_acc_10]
\n
"
_UK_MFMA_
" %[v_acc_10], acc[94:95], v[78:79], %[v_acc_10]
\n
"
_UK_MFMA_
" %[v_acc_11], acc[80:81], v[80:81], %[v_acc_11]
\n
"
_UK_MFMA_
" %[v_acc_11], acc[82:83], v[82:83], %[v_acc_11]
\n
"
" buffer_load_dwordx4 acc[216:219], %[v_os_b5], s[20:23], 0 offen offset:2048
\n
"
_UK_MFMA_
" %[v_acc_11], acc[84:85], v[84:85], %[v_acc_11]
\n
"
_UK_MFMA_
" %[v_acc_11], acc[86:87], v[86:87], %[v_acc_11]
\n
"
_UK_MFMA_
" %[v_acc_11], acc[88:89], v[88:89], %[v_acc_11]
\n
"
_UK_MFMA_
" %[v_acc_11], acc[90:91], v[90:91], %[v_acc_11]
\n
"
" buffer_load_dwordx4 acc[220:223], %[v_os_b5], s[20:23], 0 offen offset:3072
\n
"
_UK_MFMA_
" %[v_acc_11], acc[92:93], v[92:93], %[v_acc_11]
\n
"
_UK_MFMA_
" %[v_acc_11], acc[94:95], v[94:95], %[v_acc_11]
\n
"
" s_waitcnt vmcnt(32)
\n
"
_UK_MFMA_
" %[v_acc_12], acc[96:97], v[64:65], %[v_acc_12]
\n
"
_UK_MFMA_
" %[v_acc_12], acc[98:99], v[66:67], %[v_acc_12]
\n
"
" buffer_load_dwordx4 acc[224:227], %[v_os_b6], s[20:23], 0 offen
\n
"
_UK_MFMA_
" %[v_acc_12], acc[100:101], v[68:69], %[v_acc_12]
\n
"
_UK_MFMA_
" %[v_acc_12], acc[102:103], v[70:71], %[v_acc_12]
\n
"
_UK_MFMA_
" %[v_acc_12], acc[104:105], v[72:73], %[v_acc_12]
\n
"
_UK_MFMA_
" %[v_acc_12], acc[106:107], v[74:75], %[v_acc_12]
\n
"
" buffer_load_dwordx4 acc[228:231], %[v_os_b6], s[20:23], 0 offen offset:1024
\n
"
_UK_MFMA_
" %[v_acc_12], acc[108:109], v[76:77], %[v_acc_12]
\n
"
_UK_MFMA_
" %[v_acc_12], acc[110:111], v[78:79], %[v_acc_12]
\n
"
_UK_MFMA_
" %[v_acc_13], acc[96:97], v[80:81], %[v_acc_13]
\n
"
_UK_MFMA_
" %[v_acc_13], acc[98:99], v[82:83], %[v_acc_13]
\n
"
" buffer_load_dwordx4 acc[232:235], %[v_os_b6], s[20:23], 0 offen offset:2048
\n
"
_UK_MFMA_
" %[v_acc_13], acc[100:101], v[84:85], %[v_acc_13]
\n
"
_UK_MFMA_
" %[v_acc_13], acc[102:103], v[86:87], %[v_acc_13]
\n
"
_UK_MFMA_
" %[v_acc_13], acc[104:105], v[88:89], %[v_acc_13]
\n
"
_UK_MFMA_
" %[v_acc_13], acc[106:107], v[90:91], %[v_acc_13]
\n
"
" buffer_load_dwordx4 acc[236:239], %[v_os_b6], s[20:23], 0 offen offset:3072
\n
"
_UK_MFMA_
" %[v_acc_13], acc[108:109], v[92:93], %[v_acc_13]
\n
"
_UK_MFMA_
" %[v_acc_13], acc[110:111], v[94:95], %[v_acc_13]
\n
"
_UK_MFMA_
" %[v_acc_14], acc[112:113], v[64:65], %[v_acc_14]
\n
"
_UK_MFMA_
" %[v_acc_14], acc[114:115], v[66:67], %[v_acc_14]
\n
"
" buffer_load_dwordx4 acc[240:243], %[v_os_b7], s[20:23], 0 offen
\n
"
_UK_MFMA_
" %[v_acc_14], acc[116:117], v[68:69], %[v_acc_14]
\n
"
_UK_MFMA_
" %[v_acc_14], acc[118:119], v[70:71], %[v_acc_14]
\n
"
_UK_MFMA_
" %[v_acc_14], acc[120:121], v[72:73], %[v_acc_14]
\n
"
_UK_MFMA_
" %[v_acc_14], acc[122:123], v[74:75], %[v_acc_14]
\n
"
" buffer_load_dwordx4 acc[244:247], %[v_os_b7], s[20:23], 0 offen offset:1024
\n
"
_UK_MFMA_
" %[v_acc_14], acc[124:125], v[76:77], %[v_acc_14]
\n
"
_UK_MFMA_
" %[v_acc_14], acc[126:127], v[78:79], %[v_acc_14]
\n
"
_UK_MFMA_
" %[v_acc_15], acc[112:113], v[80:81], %[v_acc_15]
\n
"
_UK_MFMA_
" %[v_acc_15], acc[114:115], v[82:83], %[v_acc_15]
\n
"
" buffer_load_dwordx4 acc[248:251], %[v_os_b7], s[20:23], 0 offen offset:2048
\n
"
_UK_MFMA_
" %[v_acc_15], acc[116:117], v[84:85], %[v_acc_15]
\n
"
_UK_MFMA_
" %[v_acc_15], acc[118:119], v[86:87], %[v_acc_15]
\n
"
_UK_MFMA_
" %[v_acc_15], acc[120:121], v[88:89], %[v_acc_15]
\n
"
_UK_MFMA_
" %[v_acc_15], acc[122:123], v[90:91], %[v_acc_15]
\n
"
" buffer_load_dwordx4 acc[252:255], %[v_os_b7], s[20:23], 0 offen offset:3072
\n
"
_UK_MFMA_
" %[v_acc_15], acc[124:125], v[92:93], %[v_acc_15]
\n
"
_UK_MFMA_
" %[v_acc_15], acc[126:127], v[94:95], %[v_acc_15]
\n
"
" s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1
\n
"
" s_cmp_gt_i32 %[s_loop_cnt] 0
\n
"
" s_cbranch_scc0 L_end%=
\n
"
" s_cmp_gt_i32 %[s_loop_cnt] 2 ; move a with cond
\n
"
" s_cselect_b32 s86, %[s_tile_os_a], 0
\n
"
" s_add_u32 s16, s86, s16
\n
"
" s_addc_u32 s17, 0, s17
\n
"
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond
\n
"
" s_cselect_b32 s86, %[s_tile_os_b], 0
\n
"
" s_add_u32 s20, s86, s20
\n
"
" s_addc_u32 s21, 0, s21
\n
"
" ;------------------------------------------
\n
"
" s_waitcnt vmcnt(24) & lgkmcnt(0)
\n
"
" s_barrier
\n
"
_UK_MFMA_
" %[v_acc_0], acc[128:129], v[96:97], %[v_acc_0]
\n
"
_UK_MFMA_
" %[v_acc_0], acc[130:131], v[98:99], %[v_acc_0]
\n
"
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[20:23], 0 offen
\n
"
_UK_MFMA_
" %[v_acc_0], acc[132:133], v[100:101], %[v_acc_0]
\n
"
_UK_MFMA_
" %[v_acc_0], acc[134:135], v[102:103], %[v_acc_0]
\n
"
" buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds
\n
"
" s_add_u32 m0, %[s_size_per_issue], m0
\n
"
_UK_MFMA_
" %[v_acc_0], acc[136:137], v[104:105], %[v_acc_0]
\n
"
_UK_MFMA_
" %[v_acc_0], acc[138:139], v[106:107], %[v_acc_0]
\n
"
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[20:23], 0 offen offset:1024
\n
"
_UK_MFMA_
" %[v_acc_0], acc[140:141], v[108:109], %[v_acc_0]
\n
"
_UK_MFMA_
" %[v_acc_0], acc[142:143], v[110:111], %[v_acc_0]
\n
"
" buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds
\n
"
" s_add_u32 m0, %[s_size_per_issue], m0
\n
"
_UK_MFMA_
" %[v_acc_1], acc[128:129], v[112:113], %[v_acc_1]
\n
"
_UK_MFMA_
" %[v_acc_1], acc[130:131], v[114:115], %[v_acc_1]
\n
"
" buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[20:23], 0 offen offset:2048
\n
"
_UK_MFMA_
" %[v_acc_1], acc[132:133], v[116:117], %[v_acc_1]
\n
"
_UK_MFMA_
" %[v_acc_1], acc[134:135], v[118:119], %[v_acc_1]
\n
"
" buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds
\n
"
" s_add_u32 m0, %[s_size_per_issue], m0
\n
"
_UK_MFMA_
" %[v_acc_1], acc[136:137], v[120:121], %[v_acc_1]
\n
"
_UK_MFMA_
" %[v_acc_1], acc[138:139], v[122:123], %[v_acc_1]
\n
"
" buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[20:23], 0 offen offset:3072
\n
"
_UK_MFMA_
" %[v_acc_1], acc[140:141], v[124:125], %[v_acc_1]
\n
"
_UK_MFMA_
" %[v_acc_1], acc[142:143], v[126:127], %[v_acc_1]
\n
"
" buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds
\n
"
" s_add_u32 m0, %[s_size_per_issue], m0
\n
"
_UK_MFMA_
" %[v_acc_2], acc[144:145], v[96:97], %[v_acc_2]
\n
"
_UK_MFMA_
" %[v_acc_2], acc[146:147], v[98:99], %[v_acc_2]
\n
"
" buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[20:23], 0 offen
\n
"
_UK_MFMA_
" %[v_acc_2], acc[148:149], v[100:101], %[v_acc_2]
\n
"
_UK_MFMA_
" %[v_acc_2], acc[150:151], v[102:103], %[v_acc_2]
\n
"
" buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds
\n
"
" s_add_u32 m0, %[s_size_per_issue], m0
\n
"
_UK_MFMA_
" %[v_acc_2], acc[152:153], v[104:105], %[v_acc_2]
\n
"
_UK_MFMA_
" %[v_acc_2], acc[154:155], v[106:107], %[v_acc_2]
\n
"
" buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[20:23], 0 offen offset:1024
\n
"
_UK_MFMA_
" %[v_acc_2], acc[156:157], v[108:109], %[v_acc_2]
\n
"
_UK_MFMA_
" %[v_acc_2], acc[158:159], v[110:111], %[v_acc_2]
\n
"
" buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds
\n
"
" s_add_u32 m0, %[s_size_per_issue], m0
\n
"
_UK_MFMA_
" %[v_acc_3], acc[144:145], v[112:113], %[v_acc_3]
\n
"
_UK_MFMA_
" %[v_acc_3], acc[146:147], v[114:115], %[v_acc_3]
\n
"
" buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[20:23], 0 offen offset:2048
\n
"
_UK_MFMA_
" %[v_acc_3], acc[148:149], v[116:117], %[v_acc_3]
\n
"
_UK_MFMA_
" %[v_acc_3], acc[150:151], v[118:119], %[v_acc_3]
\n
"
" buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds
\n
"
" s_add_u32 m0, %[s_size_per_issue], m0
\n
"
_UK_MFMA_
" %[v_acc_3], acc[152:153], v[120:121], %[v_acc_3]
\n
"
_UK_MFMA_
" %[v_acc_3], acc[154:155], v[122:123], %[v_acc_3]
\n
"
" buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[20:23], 0 offen offset:3072
\n
"
_UK_MFMA_
" %[v_acc_3], acc[156:157], v[124:125], %[v_acc_3]
\n
"
_UK_MFMA_
" %[v_acc_3], acc[158:159], v[126:127], %[v_acc_3]
\n
"
" buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds
\n
"
" s_add_u32 m0, 0, %[s_m0_init]
\n
"
" s_waitcnt vmcnt(32)
\n
"
_UK_MFMA_
" %[v_acc_4], acc[160:161], v[96:97], %[v_acc_4]
\n
"
_UK_MFMA_
" %[v_acc_4], acc[162:163], v[98:99], %[v_acc_4]
\n
"
" buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[20:23], 0 offen
\n
"
_UK_MFMA_
" %[v_acc_4], acc[164:165], v[100:101], %[v_acc_4]
\n
"
_UK_MFMA_
" %[v_acc_4], acc[166:167], v[102:103], %[v_acc_4]
\n
"
" ds_read_b128 v[64:67], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_0]
\n
"
_UK_MFMA_
" %[v_acc_4], acc[168:169], v[104:105], %[v_acc_4]
\n
"
_UK_MFMA_
" %[v_acc_4], acc[170:171], v[106:107], %[v_acc_4]
\n
"
" buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[20:23], 0 offen offset:1024
\n
"
_UK_MFMA_
" %[v_acc_4], acc[172:173], v[108:109], %[v_acc_4]
\n
"
_UK_MFMA_
" %[v_acc_4], acc[174:175], v[110:111], %[v_acc_4]
\n
"
" ds_read_b128 v[68:71], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_1]
\n
"
_UK_MFMA_
" %[v_acc_5], acc[160:161], v[112:113], %[v_acc_5]
\n
"
_UK_MFMA_
" %[v_acc_5], acc[162:163], v[114:115], %[v_acc_5]
\n
"
" buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[20:23], 0 offen offset:2048
\n
"
_UK_MFMA_
" %[v_acc_5], acc[164:165], v[116:117], %[v_acc_5]
\n
"
_UK_MFMA_
" %[v_acc_5], acc[166:167], v[118:119], %[v_acc_5]
\n
"
" ds_read_b128 v[72:75], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_2]
\n
"
_UK_MFMA_
" %[v_acc_5], acc[168:169], v[120:121], %[v_acc_5]
\n
"
_UK_MFMA_
" %[v_acc_5], acc[170:171], v[122:123], %[v_acc_5]
\n
"
" buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[20:23], 0 offen offset:3072
\n
"
_UK_MFMA_
" %[v_acc_5], acc[172:173], v[124:125], %[v_acc_5]
\n
"
_UK_MFMA_
" %[v_acc_5], acc[174:175], v[126:127], %[v_acc_5]
\n
"
" ds_read_b128 v[76:79], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_3]
\n
"
_UK_MFMA_
" %[v_acc_6], acc[176:177], v[96:97], %[v_acc_6]
\n
"
_UK_MFMA_
" %[v_acc_6], acc[178:179], v[98:99], %[v_acc_6]
\n
"
" buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[20:23], 0 offen
\n
"
_UK_MFMA_
" %[v_acc_6], acc[180:181], v[100:101], %[v_acc_6]
\n
"
_UK_MFMA_
" %[v_acc_6], acc[182:183], v[102:103], %[v_acc_6]
\n
"
" ds_read_b128 v[80:83], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_4]
\n
"
_UK_MFMA_
" %[v_acc_6], acc[184:185], v[104:105], %[v_acc_6]
\n
"
_UK_MFMA_
" %[v_acc_6], acc[186:187], v[106:107], %[v_acc_6]
\n
"
" buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[20:23], 0 offen offset:1024
\n
"
_UK_MFMA_
" %[v_acc_6], acc[188:189], v[108:109], %[v_acc_6]
\n
"
_UK_MFMA_
" %[v_acc_6], acc[190:191], v[110:111], %[v_acc_6]
\n
"
" ds_read_b128 v[84:87], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_5]
\n
"
_UK_MFMA_
" %[v_acc_7], acc[176:177], v[112:113], %[v_acc_7]
\n
"
_UK_MFMA_
" %[v_acc_7], acc[178:179], v[114:115], %[v_acc_7]
\n
"
" buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[20:23], 0 offen offset:2048
\n
"
_UK_MFMA_
" %[v_acc_7], acc[180:181], v[116:117], %[v_acc_7]
\n
"
_UK_MFMA_
" %[v_acc_7], acc[182:183], v[118:119], %[v_acc_7]
\n
"
" ds_read_b128 v[88:91], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_6]
\n
"
_UK_MFMA_
" %[v_acc_7], acc[184:185], v[120:121], %[v_acc_7]
\n
"
_UK_MFMA_
" %[v_acc_7], acc[186:187], v[122:123], %[v_acc_7]
\n
"
" buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[20:23], 0 offen offset:3072
\n
"
_UK_MFMA_
" %[v_acc_7], acc[188:189], v[124:125], %[v_acc_7]
\n
"
_UK_MFMA_
" %[v_acc_7], acc[190:191], v[126:127], %[v_acc_7]
\n
"
" ds_read_b128 v[92:95], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_7]
\n
"
" s_waitcnt vmcnt(32)
\n
"
_UK_MFMA_
" %[v_acc_8], acc[192:193], v[96:97], %[v_acc_8]
\n
"
_UK_MFMA_
" %[v_acc_8], acc[194:195], v[98:99], %[v_acc_8]
\n
"
" buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[20:23], 0 offen
\n
"
_UK_MFMA_
" %[v_acc_8], acc[196:197], v[100:101], %[v_acc_8]
\n
"
_UK_MFMA_
" %[v_acc_8], acc[198:199], v[102:103], %[v_acc_8]
\n
"
_UK_MFMA_
" %[v_acc_8], acc[200:201], v[104:105], %[v_acc_8]
\n
"
_UK_MFMA_
" %[v_acc_8], acc[202:203], v[106:107], %[v_acc_8]
\n
"
" buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[20:23], 0 offen offset:1024
\n
"
_UK_MFMA_
" %[v_acc_8], acc[204:205], v[108:109], %[v_acc_8]
\n
"
_UK_MFMA_
" %[v_acc_8], acc[206:207], v[110:111], %[v_acc_8]
\n
"
_UK_MFMA_
" %[v_acc_9], acc[192:193], v[112:113], %[v_acc_9]
\n
"
_UK_MFMA_
" %[v_acc_9], acc[194:195], v[114:115], %[v_acc_9]
\n
"
" buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[20:23], 0 offen offset:2048
\n
"
_UK_MFMA_
" %[v_acc_9], acc[196:197], v[116:117], %[v_acc_9]
\n
"
_UK_MFMA_
" %[v_acc_9], acc[198:199], v[118:119], %[v_acc_9]
\n
"
_UK_MFMA_
" %[v_acc_9], acc[200:201], v[120:121], %[v_acc_9]
\n
"
_UK_MFMA_
" %[v_acc_9], acc[202:203], v[122:123], %[v_acc_9]
\n
"
" buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[20:23], 0 offen offset:3072
\n
"
_UK_MFMA_
" %[v_acc_9], acc[204:205], v[124:125], %[v_acc_9]
\n
"
_UK_MFMA_
" %[v_acc_9], acc[206:207], v[126:127], %[v_acc_9]
\n
"
_UK_MFMA_
" %[v_acc_10], acc[208:209], v[96:97], %[v_acc_10]
\n
"
_UK_MFMA_
" %[v_acc_10], acc[210:211], v[98:99], %[v_acc_10]
\n
"
" buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[20:23], 0 offen
\n
"
_UK_MFMA_
" %[v_acc_10], acc[212:213], v[100:101], %[v_acc_10]
\n
"
_UK_MFMA_
" %[v_acc_10], acc[214:215], v[102:103], %[v_acc_10]
\n
"
_UK_MFMA_
" %[v_acc_10], acc[216:217], v[104:105], %[v_acc_10]
\n
"
_UK_MFMA_
" %[v_acc_10], acc[218:219], v[106:107], %[v_acc_10]
\n
"
" buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[20:23], 0 offen offset:1024
\n
"
_UK_MFMA_
" %[v_acc_10], acc[220:221], v[108:109], %[v_acc_10]
\n
"
_UK_MFMA_
" %[v_acc_10], acc[222:223], v[110:111], %[v_acc_10]
\n
"
_UK_MFMA_
" %[v_acc_11], acc[208:209], v[112:113], %[v_acc_11]
\n
"
_UK_MFMA_
" %[v_acc_11], acc[210:211], v[114:115], %[v_acc_11]
\n
"
" buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[20:23], 0 offen offset:2048
\n
"
_UK_MFMA_
" %[v_acc_11], acc[212:213], v[116:117], %[v_acc_11]
\n
"
_UK_MFMA_
" %[v_acc_11], acc[214:215], v[118:119], %[v_acc_11]
\n
"
_UK_MFMA_
" %[v_acc_11], acc[216:217], v[120:121], %[v_acc_11]
\n
"
_UK_MFMA_
" %[v_acc_11], acc[218:219], v[122:123], %[v_acc_11]
\n
"
" buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[20:23], 0 offen offset:3072
\n
"
_UK_MFMA_
" %[v_acc_11], acc[220:221], v[124:125], %[v_acc_11]
\n
"
_UK_MFMA_
" %[v_acc_11], acc[222:223], v[126:127], %[v_acc_11]
\n
"
" s_waitcnt vmcnt(32)
\n
"
_UK_MFMA_
" %[v_acc_12], acc[224:225], v[96:97], %[v_acc_12]
\n
"
_UK_MFMA_
" %[v_acc_12], acc[226:227], v[98:99], %[v_acc_12]
\n
"
" buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[20:23], 0 offen
\n
"
_UK_MFMA_
" %[v_acc_12], acc[228:229], v[100:101], %[v_acc_12]
\n
"
_UK_MFMA_
" %[v_acc_12], acc[230:231], v[102:103], %[v_acc_12]
\n
"
_UK_MFMA_
" %[v_acc_12], acc[232:233], v[104:105], %[v_acc_12]
\n
"
_UK_MFMA_
" %[v_acc_12], acc[234:235], v[106:107], %[v_acc_12]
\n
"
" buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[20:23], 0 offen offset:1024
\n
"
_UK_MFMA_
" %[v_acc_12], acc[236:237], v[108:109], %[v_acc_12]
\n
"
_UK_MFMA_
" %[v_acc_12], acc[238:239], v[110:111], %[v_acc_12]
\n
"
_UK_MFMA_
" %[v_acc_13], acc[224:225], v[112:113], %[v_acc_13]
\n
"
_UK_MFMA_
" %[v_acc_13], acc[226:227], v[114:115], %[v_acc_13]
\n
"
" buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[20:23], 0 offen offset:2048
\n
"
_UK_MFMA_
" %[v_acc_13], acc[228:229], v[116:117], %[v_acc_13]
\n
"
_UK_MFMA_
" %[v_acc_13], acc[230:231], v[118:119], %[v_acc_13]
\n
"
_UK_MFMA_
" %[v_acc_13], acc[232:233], v[120:121], %[v_acc_13]
\n
"
_UK_MFMA_
" %[v_acc_13], acc[234:235], v[122:123], %[v_acc_13]
\n
"
" buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[20:23], 0 offen offset:3072
\n
"
_UK_MFMA_
" %[v_acc_13], acc[236:237], v[124:125], %[v_acc_13]
\n
"
_UK_MFMA_
" %[v_acc_13], acc[238:239], v[126:127], %[v_acc_13]
\n
"
_UK_MFMA_
" %[v_acc_14], acc[240:241], v[96:97], %[v_acc_14]
\n
"
_UK_MFMA_
" %[v_acc_14], acc[242:243], v[98:99], %[v_acc_14]
\n
"
" buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[20:23], 0 offen
\n
"
_UK_MFMA_
" %[v_acc_14], acc[244:245], v[100:101], %[v_acc_14]
\n
"
_UK_MFMA_
" %[v_acc_14], acc[246:247], v[102:103], %[v_acc_14]
\n
"
_UK_MFMA_
" %[v_acc_14], acc[248:249], v[104:105], %[v_acc_14]
\n
"
_UK_MFMA_
" %[v_acc_14], acc[250:251], v[106:107], %[v_acc_14]
\n
"
" buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[20:23], 0 offen offset:1024
\n
"
_UK_MFMA_
" %[v_acc_14], acc[252:253], v[108:109], %[v_acc_14]
\n
"
_UK_MFMA_
" %[v_acc_14], acc[254:255], v[110:111], %[v_acc_14]
\n
"
_UK_MFMA_
" %[v_acc_15], acc[240:241], v[112:113], %[v_acc_15]
\n
"
_UK_MFMA_
" %[v_acc_15], acc[242:243], v[114:115], %[v_acc_15]
\n
"
" buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[20:23], 0 offen offset:2048
\n
"
_UK_MFMA_
" %[v_acc_15], acc[244:245], v[116:117], %[v_acc_15]
\n
"
_UK_MFMA_
" %[v_acc_15], acc[246:247], v[118:119], %[v_acc_15]
\n
"
_UK_MFMA_
" %[v_acc_15], acc[248:249], v[120:121], %[v_acc_15]
\n
"
_UK_MFMA_
" %[v_acc_15], acc[250:251], v[122:123], %[v_acc_15]
\n
"
" buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[20:23], 0 offen offset:3072
\n
"
_UK_MFMA_
" %[v_acc_15], acc[252:253], v[124:125], %[v_acc_15]
\n
"
_UK_MFMA_
" %[v_acc_15], acc[254:255], v[126:127], %[v_acc_15]
\n
"
" s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1
\n
"
" s_cmp_gt_i32 %[s_loop_cnt] 0
\n
"
" s_cbranch_scc0 L_end%=
\n
"
" s_cmp_gt_i32 %[s_loop_cnt] 2 ; move a with cond
\n
"
" s_cselect_b32 s86, %[s_tile_os_a], 0
\n
"
" s_add_u32 s16, s86, s16
\n
"
" s_addc_u32 s17, 0, s17
\n
"
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond
\n
"
" s_cselect_b32 s86, %[s_tile_os_b], 0
\n
"
" s_add_u32 s20, s86, s20
\n
"
" s_addc_u32 s21, 0, s21
\n
"
" s_branch L_start%=
\n
"
"L_end%=:
\n
"
" s_nop 2
\n
"
#undef _UK_MFMA_
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
View file @
4525c5d7
...
...
@@ -304,7 +304,7 @@ struct FmhaBwdDQDKDVKernel
template
<
bool
Cond
=
!
kIsGroupMode
>
CK_TILE_HOST
static
constexpr
std
::
enable_if_t
<
Cond
,
Kargs
>
MakeKargs
(
const
void
*
q_ptr
,
MakeKargs
Impl
(
const
void
*
q_ptr
,
const
void
*
k_ptr
,
const
void
*
v_ptr
,
const
void
*
bias_ptr
,
...
...
@@ -470,9 +470,251 @@ struct FmhaBwdDQDKDVKernel
return
kargs
;
}
template
<
bool
Cond
=
kIsGroupMode
>
// std::variant<> can't take in a list initializer, overload for backward compatibility
template
<
bool
Cond
=
!
kIsGroupMode
>
CK_TILE_HOST
static
constexpr
std
::
enable_if_t
<
Cond
,
Kargs
>
MakeKargs
(
const
void
*
q_ptr
,
const
void
*
k_ptr
,
const
void
*
v_ptr
,
const
void
*
bias_ptr
,
const
void
*
lse_ptr
,
const
void
*
do_ptr
,
const
void
*
d_ptr
,
void
*
rand_val_ptr
,
void
*
dk_ptr
,
void
*
dv_ptr
,
void
*
dbias_ptr
,
void
*
dq_acc_ptr
,
ck_tile
::
index_t
seqlen_q
,
ck_tile
::
index_t
seqlen_k
,
ck_tile
::
index_t
hdim_q
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
num_head_q
,
ck_tile
::
index_t
nhead_ratio_qk
,
float
scale
,
ck_tile
::
index_t
stride_q
,
ck_tile
::
index_t
stride_k
,
ck_tile
::
index_t
stride_v
,
ck_tile
::
index_t
stride_bias
,
ck_tile
::
index_t
stride_randval
,
ck_tile
::
index_t
stride_do
,
ck_tile
::
index_t
stride_dq_acc
,
ck_tile
::
index_t
stride_dk
,
ck_tile
::
index_t
stride_dv
,
ck_tile
::
index_t
stride_dbias
,
ck_tile
::
index_t
nhead_stride_q
,
ck_tile
::
index_t
nhead_stride_k
,
ck_tile
::
index_t
nhead_stride_v
,
ck_tile
::
index_t
nhead_stride_bias
,
ck_tile
::
index_t
nhead_stride_randval
,
ck_tile
::
index_t
nhead_stride_do
,
ck_tile
::
index_t
nhead_stride_lsed
,
ck_tile
::
index_t
nhead_stride_dq_acc
,
ck_tile
::
index_t
nhead_stride_dk
,
ck_tile
::
index_t
nhead_stride_dv
,
ck_tile
::
index_t
nhead_stride_dbias
,
ck_tile
::
index_t
batch_stride_q
,
ck_tile
::
index_t
batch_stride_k
,
ck_tile
::
index_t
batch_stride_v
,
ck_tile
::
index_t
batch_stride_bias
,
ck_tile
::
index_t
batch_stride_randval
,
ck_tile
::
index_t
batch_stride_do
,
ck_tile
::
index_t
batch_stride_lsed
,
ck_tile
::
index_t
batch_stride_dq_acc
,
ck_tile
::
index_t
batch_stride_dk
,
ck_tile
::
index_t
batch_stride_dv
,
ck_tile
::
index_t
batch_stride_dbias
,
ck_tile
::
index_t
split_stride_dq_acc
,
ck_tile
::
index_t
window_size_left
,
ck_tile
::
index_t
window_size_right
,
ck_tile
::
index_t
mask_type
,
float
p_drop
,
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
)
{
return
MakeKargsImpl
(
q_ptr
,
k_ptr
,
v_ptr
,
bias_ptr
,
lse_ptr
,
do_ptr
,
d_ptr
,
rand_val_ptr
,
dk_ptr
,
dv_ptr
,
dbias_ptr
,
dq_acc_ptr
,
seqlen_q
,
seqlen_k
,
hdim_q
,
hdim_v
,
num_head_q
,
nhead_ratio_qk
,
scale
,
stride_q
,
stride_k
,
stride_v
,
stride_bias
,
stride_randval
,
stride_do
,
stride_dq_acc
,
stride_dk
,
stride_dv
,
stride_dbias
,
nhead_stride_q
,
nhead_stride_k
,
nhead_stride_v
,
nhead_stride_bias
,
nhead_stride_randval
,
nhead_stride_do
,
nhead_stride_lsed
,
nhead_stride_dq_acc
,
nhead_stride_dk
,
nhead_stride_dv
,
nhead_stride_dbias
,
batch_stride_q
,
batch_stride_k
,
batch_stride_v
,
batch_stride_bias
,
batch_stride_randval
,
batch_stride_do
,
batch_stride_lsed
,
batch_stride_dq_acc
,
batch_stride_dk
,
batch_stride_dv
,
batch_stride_dbias
,
split_stride_dq_acc
,
window_size_left
,
window_size_right
,
mask_type
,
p_drop
,
std
::
make_pair
(
std
::
get
<
0
>
(
drop_seed_offset
),
std
::
get
<
1
>
(
drop_seed_offset
)));
}
// std::variant<> can't take in a list initializer, overload for backward compatibility
template
<
bool
Cond
=
!
kIsGroupMode
>
CK_TILE_HOST
static
constexpr
std
::
enable_if_t
<
Cond
,
Kargs
>
MakeKargs
(
const
void
*
q_ptr
,
const
void
*
k_ptr
,
const
void
*
v_ptr
,
const
void
*
bias_ptr
,
const
void
*
lse_ptr
,
const
void
*
do_ptr
,
const
void
*
d_ptr
,
void
*
rand_val_ptr
,
void
*
dk_ptr
,
void
*
dv_ptr
,
void
*
dbias_ptr
,
void
*
dq_acc_ptr
,
ck_tile
::
index_t
seqlen_q
,
ck_tile
::
index_t
seqlen_k
,
ck_tile
::
index_t
hdim_q
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
num_head_q
,
ck_tile
::
index_t
nhead_ratio_qk
,
float
scale
,
ck_tile
::
index_t
stride_q
,
ck_tile
::
index_t
stride_k
,
ck_tile
::
index_t
stride_v
,
ck_tile
::
index_t
stride_bias
,
ck_tile
::
index_t
stride_randval
,
ck_tile
::
index_t
stride_do
,
ck_tile
::
index_t
stride_dq_acc
,
ck_tile
::
index_t
stride_dk
,
ck_tile
::
index_t
stride_dv
,
ck_tile
::
index_t
stride_dbias
,
ck_tile
::
index_t
nhead_stride_q
,
ck_tile
::
index_t
nhead_stride_k
,
ck_tile
::
index_t
nhead_stride_v
,
ck_tile
::
index_t
nhead_stride_bias
,
ck_tile
::
index_t
nhead_stride_randval
,
ck_tile
::
index_t
nhead_stride_do
,
ck_tile
::
index_t
nhead_stride_lsed
,
ck_tile
::
index_t
nhead_stride_dq_acc
,
ck_tile
::
index_t
nhead_stride_dk
,
ck_tile
::
index_t
nhead_stride_dv
,
ck_tile
::
index_t
nhead_stride_dbias
,
ck_tile
::
index_t
batch_stride_q
,
ck_tile
::
index_t
batch_stride_k
,
ck_tile
::
index_t
batch_stride_v
,
ck_tile
::
index_t
batch_stride_bias
,
ck_tile
::
index_t
batch_stride_randval
,
ck_tile
::
index_t
batch_stride_do
,
ck_tile
::
index_t
batch_stride_lsed
,
ck_tile
::
index_t
batch_stride_dq_acc
,
ck_tile
::
index_t
batch_stride_dk
,
ck_tile
::
index_t
batch_stride_dv
,
ck_tile
::
index_t
batch_stride_dbias
,
ck_tile
::
index_t
split_stride_dq_acc
,
ck_tile
::
index_t
window_size_left
,
ck_tile
::
index_t
window_size_right
,
ck_tile
::
index_t
mask_type
,
float
p_drop
,
const
std
::
tuple
<
const
void
*
,
const
void
*>&
drop_seed_offset
)
{
return
MakeKargsImpl
(
q_ptr
,
k_ptr
,
v_ptr
,
bias_ptr
,
lse_ptr
,
do_ptr
,
d_ptr
,
rand_val_ptr
,
dk_ptr
,
dv_ptr
,
dbias_ptr
,
dq_acc_ptr
,
seqlen_q
,
seqlen_k
,
hdim_q
,
hdim_v
,
num_head_q
,
nhead_ratio_qk
,
scale
,
stride_q
,
stride_k
,
stride_v
,
stride_bias
,
stride_randval
,
stride_do
,
stride_dq_acc
,
stride_dk
,
stride_dv
,
stride_dbias
,
nhead_stride_q
,
nhead_stride_k
,
nhead_stride_v
,
nhead_stride_bias
,
nhead_stride_randval
,
nhead_stride_do
,
nhead_stride_lsed
,
nhead_stride_dq_acc
,
nhead_stride_dk
,
nhead_stride_dv
,
nhead_stride_dbias
,
batch_stride_q
,
batch_stride_k
,
batch_stride_v
,
batch_stride_bias
,
batch_stride_randval
,
batch_stride_do
,
batch_stride_lsed
,
batch_stride_dq_acc
,
batch_stride_dk
,
batch_stride_dv
,
batch_stride_dbias
,
split_stride_dq_acc
,
window_size_left
,
window_size_right
,
mask_type
,
p_drop
,
std
::
make_pair
(
std
::
get
<
0
>
(
drop_seed_offset
),
std
::
get
<
1
>
(
drop_seed_offset
)));
}
template
<
bool
Cond
=
kIsGroupMode
>
CK_TILE_HOST
static
constexpr
std
::
enable_if_t
<
Cond
,
Kargs
>
MakeKargsImpl
(
const
void
*
q_ptr
,
const
void
*
k_ptr
,
const
void
*
v_ptr
,
const
void
*
bias_ptr
,
...
...
@@ -616,6 +858,208 @@ struct FmhaBwdDQDKDVKernel
return
kargs
;
}
// std::variant<> can't take in a list initializer, overload for backward compatibility
template
<
bool
Cond
=
kIsGroupMode
>
CK_TILE_HOST
static
constexpr
std
::
enable_if_t
<
Cond
,
Kargs
>
MakeKargs
(
const
void
*
q_ptr
,
const
void
*
k_ptr
,
const
void
*
v_ptr
,
const
void
*
bias_ptr
,
const
void
*
lse_ptr
,
const
void
*
do_ptr
,
const
void
*
d_ptr
,
void
*
rand_val_ptr
,
void
*
dk_ptr
,
void
*
dv_ptr
,
void
*
dbias_ptr
,
void
*
dq_acc_ptr
,
const
void
*
seqstart_q_ptr
,
const
void
*
seqstart_k_ptr
,
const
void
*
seqlen_k_ptr
,
ck_tile
::
index_t
hdim_q
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
num_head_q
,
ck_tile
::
index_t
nhead_ratio_qk
,
float
scale
,
ck_tile
::
index_t
stride_q
,
ck_tile
::
index_t
stride_k
,
ck_tile
::
index_t
stride_v
,
ck_tile
::
index_t
stride_bias
,
ck_tile
::
index_t
stride_randval
,
ck_tile
::
index_t
stride_do
,
ck_tile
::
index_t
stride_dq_acc
,
ck_tile
::
index_t
stride_dk
,
ck_tile
::
index_t
stride_dv
,
ck_tile
::
index_t
stride_dbias
,
ck_tile
::
index_t
nhead_stride_q
,
ck_tile
::
index_t
nhead_stride_k
,
ck_tile
::
index_t
nhead_stride_v
,
ck_tile
::
index_t
nhead_stride_bias
,
ck_tile
::
index_t
nhead_stride_randval
,
ck_tile
::
index_t
nhead_stride_do
,
ck_tile
::
index_t
nhead_stride_lsed
,
ck_tile
::
index_t
nhead_stride_dq_acc
,
ck_tile
::
index_t
nhead_stride_dk
,
ck_tile
::
index_t
nhead_stride_dv
,
ck_tile
::
index_t
nhead_stride_dbias
,
ck_tile
::
index_t
split_stride_dq_acc
,
ck_tile
::
index_t
window_size_left
,
ck_tile
::
index_t
window_size_right
,
ck_tile
::
index_t
mask_type
,
float
p_drop
,
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
)
{
return
MakeKargsImpl
(
q_ptr
,
k_ptr
,
v_ptr
,
bias_ptr
,
lse_ptr
,
do_ptr
,
d_ptr
,
rand_val_ptr
,
dk_ptr
,
dv_ptr
,
dbias_ptr
,
dq_acc_ptr
,
seqstart_q_ptr
,
seqstart_k_ptr
,
seqlen_k_ptr
,
hdim_q
,
hdim_v
,
num_head_q
,
nhead_ratio_qk
,
scale
,
stride_q
,
stride_k
,
stride_v
,
stride_bias
,
stride_randval
,
stride_do
,
stride_dq_acc
,
stride_dk
,
stride_dv
,
stride_dbias
,
nhead_stride_q
,
nhead_stride_k
,
nhead_stride_v
,
nhead_stride_bias
,
nhead_stride_randval
,
nhead_stride_do
,
nhead_stride_lsed
,
nhead_stride_dq_acc
,
nhead_stride_dk
,
nhead_stride_dv
,
nhead_stride_dbias
,
split_stride_dq_acc
,
window_size_left
,
window_size_right
,
mask_type
,
p_drop
,
std
::
make_pair
(
std
::
get
<
0
>
(
drop_seed_offset
),
std
::
get
<
1
>
(
drop_seed_offset
)));
}
// std::variant<> can't take in a list initializer, overload for backward compatibility
template
<
bool
Cond
=
kIsGroupMode
>
CK_TILE_HOST
static
constexpr
std
::
enable_if_t
<
Cond
,
Kargs
>
MakeKargs
(
const
void
*
q_ptr
,
const
void
*
k_ptr
,
const
void
*
v_ptr
,
const
void
*
bias_ptr
,
const
void
*
lse_ptr
,
const
void
*
do_ptr
,
const
void
*
d_ptr
,
void
*
rand_val_ptr
,
void
*
dk_ptr
,
void
*
dv_ptr
,
void
*
dbias_ptr
,
void
*
dq_acc_ptr
,
const
void
*
seqstart_q_ptr
,
const
void
*
seqstart_k_ptr
,
const
void
*
seqlen_k_ptr
,
ck_tile
::
index_t
hdim_q
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
num_head_q
,
ck_tile
::
index_t
nhead_ratio_qk
,
float
scale
,
ck_tile
::
index_t
stride_q
,
ck_tile
::
index_t
stride_k
,
ck_tile
::
index_t
stride_v
,
ck_tile
::
index_t
stride_bias
,
ck_tile
::
index_t
stride_randval
,
ck_tile
::
index_t
stride_do
,
ck_tile
::
index_t
stride_dq_acc
,
ck_tile
::
index_t
stride_dk
,
ck_tile
::
index_t
stride_dv
,
ck_tile
::
index_t
stride_dbias
,
ck_tile
::
index_t
nhead_stride_q
,
ck_tile
::
index_t
nhead_stride_k
,
ck_tile
::
index_t
nhead_stride_v
,
ck_tile
::
index_t
nhead_stride_bias
,
ck_tile
::
index_t
nhead_stride_randval
,
ck_tile
::
index_t
nhead_stride_do
,
ck_tile
::
index_t
nhead_stride_lsed
,
ck_tile
::
index_t
nhead_stride_dq_acc
,
ck_tile
::
index_t
nhead_stride_dk
,
ck_tile
::
index_t
nhead_stride_dv
,
ck_tile
::
index_t
nhead_stride_dbias
,
ck_tile
::
index_t
split_stride_dq_acc
,
ck_tile
::
index_t
window_size_left
,
ck_tile
::
index_t
window_size_right
,
ck_tile
::
index_t
mask_type
,
float
p_drop
,
const
std
::
tuple
<
const
void
*
,
const
void
*>&
drop_seed_offset
)
{
return
MakeKargsImpl
(
q_ptr
,
k_ptr
,
v_ptr
,
bias_ptr
,
lse_ptr
,
do_ptr
,
d_ptr
,
rand_val_ptr
,
dk_ptr
,
dv_ptr
,
dbias_ptr
,
dq_acc_ptr
,
seqstart_q_ptr
,
seqstart_k_ptr
,
seqlen_k_ptr
,
hdim_q
,
hdim_v
,
num_head_q
,
nhead_ratio_qk
,
scale
,
stride_q
,
stride_k
,
stride_v
,
stride_bias
,
stride_randval
,
stride_do
,
stride_dq_acc
,
stride_dk
,
stride_dv
,
stride_dbias
,
nhead_stride_q
,
nhead_stride_k
,
nhead_stride_v
,
nhead_stride_bias
,
nhead_stride_randval
,
nhead_stride_do
,
nhead_stride_lsed
,
nhead_stride_dq_acc
,
nhead_stride_dk
,
nhead_stride_dv
,
nhead_stride_dbias
,
split_stride_dq_acc
,
window_size_left
,
window_size_right
,
mask_type
,
p_drop
,
std
::
make_pair
(
std
::
get
<
0
>
(
drop_seed_offset
),
std
::
get
<
1
>
(
drop_seed_offset
)));
}
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size_
,
ck_tile
::
index_t
nhead_
,
ck_tile
::
index_t
seqlen_k_
)
{
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
View file @
4525c5d7
...
...
@@ -64,7 +64,7 @@ struct FmhaFwdKernel
template
<
>
struct
t2s
<
ck_tile
::
bf8_t
>
{
static
constexpr
const
char
*
name
=
"bf8"
;
};
// clang-format on
__host__
static
std
::
string
GetName
()
CK_TILE_HOST
static
std
::
string
GetName
()
{
// sync with generate.py
// clang-format off
...
...
@@ -267,8 +267,8 @@ struct FmhaFwdKernel
using
Kargs
=
std
::
conditional_t
<
kIsGroupMode
,
FmhaFwdGroupModeKargs
,
FmhaFwdBatchModeKargs
>
;
template
<
bool
Cond
=
!
kIsGroupMode
>
__host__
static
constexpr
std
::
enable_if_t
<
Cond
,
Kargs
>
MakeKargs
(
const
void
*
q_ptr
,
CK_TILE_HOST
static
constexpr
std
::
enable_if_t
<
Cond
,
Kargs
>
MakeKargs
Impl
(
const
void
*
q_ptr
,
const
void
*
k_ptr
,
const
void
*
v_ptr
,
const
void
*
bias_ptr
,
...
...
@@ -399,9 +399,191 @@ struct FmhaFwdKernel
return
kargs
;
}
template
<
bool
Cond
=
kIsGroupMode
>
__host__
static
constexpr
std
::
enable_if_t
<
Cond
,
Kargs
>
// std::variant<> can't take in a list initializer, overload for backward compatibility
template
<
bool
Cond
=
!
kIsGroupMode
>
CK_TILE_HOST
static
constexpr
std
::
enable_if_t
<
Cond
,
Kargs
>
MakeKargs
(
const
void
*
q_ptr
,
const
void
*
k_ptr
,
const
void
*
v_ptr
,
const
void
*
bias_ptr
,
void
*
rand_val_ptr
,
void
*
lse_ptr
,
void
*
o_ptr
,
ck_tile
::
index_t
seqlen_q
,
ck_tile
::
index_t
seqlen_k
,
ck_tile
::
index_t
hdim_q
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
num_head_q
,
ck_tile
::
index_t
nhead_ratio_qk
,
float
scale_s
,
float
scale_p
,
float
scale_o
,
ck_tile
::
index_t
stride_q
,
ck_tile
::
index_t
stride_k
,
ck_tile
::
index_t
stride_v
,
ck_tile
::
index_t
stride_bias
,
ck_tile
::
index_t
stride_randval
,
ck_tile
::
index_t
stride_o
,
ck_tile
::
index_t
nhead_stride_q
,
ck_tile
::
index_t
nhead_stride_k
,
ck_tile
::
index_t
nhead_stride_v
,
ck_tile
::
index_t
nhead_stride_bias
,
ck_tile
::
index_t
nhead_stride_randval
,
ck_tile
::
index_t
nhead_stride_lse
,
ck_tile
::
index_t
nhead_stride_o
,
ck_tile
::
index_t
batch_stride_q
,
ck_tile
::
index_t
batch_stride_k
,
ck_tile
::
index_t
batch_stride_v
,
ck_tile
::
index_t
batch_stride_bias
,
ck_tile
::
index_t
batch_stride_randval
,
ck_tile
::
index_t
batch_stride_lse
,
ck_tile
::
index_t
batch_stride_o
,
ck_tile
::
index_t
window_size_left
,
ck_tile
::
index_t
window_size_right
,
ck_tile
::
index_t
mask_type
,
float
p_drop
,
bool
s_randval
,
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
)
{
return
MakeKargsImpl
(
q_ptr
,
k_ptr
,
v_ptr
,
bias_ptr
,
rand_val_ptr
,
lse_ptr
,
o_ptr
,
seqlen_q
,
seqlen_k
,
hdim_q
,
hdim_v
,
num_head_q
,
nhead_ratio_qk
,
scale_s
,
scale_p
,
scale_o
,
stride_q
,
stride_k
,
stride_v
,
stride_bias
,
stride_randval
,
stride_o
,
nhead_stride_q
,
nhead_stride_k
,
nhead_stride_v
,
nhead_stride_bias
,
nhead_stride_randval
,
nhead_stride_lse
,
nhead_stride_o
,
batch_stride_q
,
batch_stride_k
,
batch_stride_v
,
batch_stride_bias
,
batch_stride_randval
,
batch_stride_lse
,
batch_stride_o
,
window_size_left
,
window_size_right
,
mask_type
,
p_drop
,
s_randval
,
std
::
make_pair
(
std
::
get
<
0
>
(
drop_seed_offset
),
std
::
get
<
1
>
(
drop_seed_offset
)));
}
// std::variant<> can't take in a list initializer, overload for backward compatibility
template
<
bool
Cond
=
!
kIsGroupMode
>
CK_TILE_HOST
static
constexpr
std
::
enable_if_t
<
Cond
,
Kargs
>
MakeKargs
(
const
void
*
q_ptr
,
const
void
*
k_ptr
,
const
void
*
v_ptr
,
const
void
*
bias_ptr
,
void
*
rand_val_ptr
,
void
*
lse_ptr
,
void
*
o_ptr
,
ck_tile
::
index_t
seqlen_q
,
ck_tile
::
index_t
seqlen_k
,
ck_tile
::
index_t
hdim_q
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
num_head_q
,
ck_tile
::
index_t
nhead_ratio_qk
,
float
scale_s
,
float
scale_p
,
float
scale_o
,
ck_tile
::
index_t
stride_q
,
ck_tile
::
index_t
stride_k
,
ck_tile
::
index_t
stride_v
,
ck_tile
::
index_t
stride_bias
,
ck_tile
::
index_t
stride_randval
,
ck_tile
::
index_t
stride_o
,
ck_tile
::
index_t
nhead_stride_q
,
ck_tile
::
index_t
nhead_stride_k
,
ck_tile
::
index_t
nhead_stride_v
,
ck_tile
::
index_t
nhead_stride_bias
,
ck_tile
::
index_t
nhead_stride_randval
,
ck_tile
::
index_t
nhead_stride_lse
,
ck_tile
::
index_t
nhead_stride_o
,
ck_tile
::
index_t
batch_stride_q
,
ck_tile
::
index_t
batch_stride_k
,
ck_tile
::
index_t
batch_stride_v
,
ck_tile
::
index_t
batch_stride_bias
,
ck_tile
::
index_t
batch_stride_randval
,
ck_tile
::
index_t
batch_stride_lse
,
ck_tile
::
index_t
batch_stride_o
,
ck_tile
::
index_t
window_size_left
,
ck_tile
::
index_t
window_size_right
,
ck_tile
::
index_t
mask_type
,
float
p_drop
,
bool
s_randval
,
const
std
::
tuple
<
const
void
*
,
const
void
*>&
drop_seed_offset
)
{
return
MakeKargsImpl
(
q_ptr
,
k_ptr
,
v_ptr
,
bias_ptr
,
rand_val_ptr
,
lse_ptr
,
o_ptr
,
seqlen_q
,
seqlen_k
,
hdim_q
,
hdim_v
,
num_head_q
,
nhead_ratio_qk
,
scale_s
,
scale_p
,
scale_o
,
stride_q
,
stride_k
,
stride_v
,
stride_bias
,
stride_randval
,
stride_o
,
nhead_stride_q
,
nhead_stride_k
,
nhead_stride_v
,
nhead_stride_bias
,
nhead_stride_randval
,
nhead_stride_lse
,
nhead_stride_o
,
batch_stride_q
,
batch_stride_k
,
batch_stride_v
,
batch_stride_bias
,
batch_stride_randval
,
batch_stride_lse
,
batch_stride_o
,
window_size_left
,
window_size_right
,
mask_type
,
p_drop
,
s_randval
,
std
::
make_pair
(
std
::
get
<
0
>
(
drop_seed_offset
),
std
::
get
<
1
>
(
drop_seed_offset
)));
}
template
<
bool
Cond
=
kIsGroupMode
>
CK_TILE_HOST
static
constexpr
std
::
enable_if_t
<
Cond
,
Kargs
>
MakeKargsImpl
(
const
void
*
q_ptr
,
const
void
*
k_ptr
,
const
void
*
v_ptr
,
const
void
*
bias_ptr
,
...
...
@@ -522,7 +704,165 @@ struct FmhaFwdKernel
return
kargs
;
}
__host__
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size_
,
// std::variant<> can't take in a list initializer, overload for backward compatibility
template
<
bool
Cond
=
kIsGroupMode
>
CK_TILE_HOST
static
constexpr
std
::
enable_if_t
<
Cond
,
Kargs
>
MakeKargs
(
const
void
*
q_ptr
,
const
void
*
k_ptr
,
const
void
*
v_ptr
,
const
void
*
bias_ptr
,
void
*
rand_val_ptr
,
void
*
lse_ptr
,
void
*
o_ptr
,
const
void
*
seqstart_q_ptr
,
const
void
*
seqstart_k_ptr
,
const
void
*
seqlen_k_ptr
,
ck_tile
::
index_t
hdim_q
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
num_head_q
,
ck_tile
::
index_t
nhead_ratio_qk
,
float
scale_s
,
float
scale_p
,
float
scale_o
,
ck_tile
::
index_t
stride_q
,
ck_tile
::
index_t
stride_k
,
ck_tile
::
index_t
stride_v
,
ck_tile
::
index_t
stride_bias
,
ck_tile
::
index_t
stride_randval
,
ck_tile
::
index_t
stride_o
,
ck_tile
::
index_t
nhead_stride_q
,
ck_tile
::
index_t
nhead_stride_k
,
ck_tile
::
index_t
nhead_stride_v
,
ck_tile
::
index_t
nhead_stride_bias
,
ck_tile
::
index_t
nhead_stride_randval
,
ck_tile
::
index_t
nhead_stride_lse
,
ck_tile
::
index_t
nhead_stride_o
,
ck_tile
::
index_t
window_size_left
,
ck_tile
::
index_t
window_size_right
,
ck_tile
::
index_t
mask_type
,
float
p_drop
,
bool
s_randval
,
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
)
{
return
MakeKargsImpl
(
q_ptr
,
k_ptr
,
v_ptr
,
bias_ptr
,
rand_val_ptr
,
lse_ptr
,
o_ptr
,
seqstart_q_ptr
,
seqstart_k_ptr
,
seqlen_k_ptr
,
hdim_q
,
hdim_v
,
num_head_q
,
nhead_ratio_qk
,
scale_s
,
scale_p
,
scale_o
,
stride_q
,
stride_k
,
stride_v
,
stride_bias
,
stride_randval
,
stride_o
,
nhead_stride_q
,
nhead_stride_k
,
nhead_stride_v
,
nhead_stride_bias
,
nhead_stride_randval
,
nhead_stride_lse
,
nhead_stride_o
,
window_size_left
,
window_size_right
,
mask_type
,
p_drop
,
s_randval
,
std
::
make_pair
(
std
::
get
<
0
>
(
drop_seed_offset
),
std
::
get
<
1
>
(
drop_seed_offset
)));
}
// std::variant<> can't take in a list initializer, overload for backward compatibility
template
<
bool
Cond
=
kIsGroupMode
>
CK_TILE_HOST
static
constexpr
std
::
enable_if_t
<
Cond
,
Kargs
>
MakeKargs
(
const
void
*
q_ptr
,
const
void
*
k_ptr
,
const
void
*
v_ptr
,
const
void
*
bias_ptr
,
void
*
rand_val_ptr
,
void
*
lse_ptr
,
void
*
o_ptr
,
const
void
*
seqstart_q_ptr
,
const
void
*
seqstart_k_ptr
,
const
void
*
seqlen_k_ptr
,
ck_tile
::
index_t
hdim_q
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
num_head_q
,
ck_tile
::
index_t
nhead_ratio_qk
,
float
scale_s
,
float
scale_p
,
float
scale_o
,
ck_tile
::
index_t
stride_q
,
ck_tile
::
index_t
stride_k
,
ck_tile
::
index_t
stride_v
,
ck_tile
::
index_t
stride_bias
,
ck_tile
::
index_t
stride_randval
,
ck_tile
::
index_t
stride_o
,
ck_tile
::
index_t
nhead_stride_q
,
ck_tile
::
index_t
nhead_stride_k
,
ck_tile
::
index_t
nhead_stride_v
,
ck_tile
::
index_t
nhead_stride_bias
,
ck_tile
::
index_t
nhead_stride_randval
,
ck_tile
::
index_t
nhead_stride_lse
,
ck_tile
::
index_t
nhead_stride_o
,
ck_tile
::
index_t
window_size_left
,
ck_tile
::
index_t
window_size_right
,
ck_tile
::
index_t
mask_type
,
float
p_drop
,
bool
s_randval
,
const
std
::
tuple
<
const
void
*
,
const
void
*>&
drop_seed_offset
)
{
return
MakeKargsImpl
(
q_ptr
,
k_ptr
,
v_ptr
,
bias_ptr
,
rand_val_ptr
,
lse_ptr
,
o_ptr
,
seqstart_q_ptr
,
seqstart_k_ptr
,
seqlen_k_ptr
,
hdim_q
,
hdim_v
,
num_head_q
,
nhead_ratio_qk
,
scale_s
,
scale_p
,
scale_o
,
stride_q
,
stride_k
,
stride_v
,
stride_bias
,
stride_randval
,
stride_o
,
nhead_stride_q
,
nhead_stride_k
,
nhead_stride_v
,
nhead_stride_bias
,
nhead_stride_randval
,
nhead_stride_lse
,
nhead_stride_o
,
window_size_left
,
window_size_right
,
mask_type
,
p_drop
,
s_randval
,
std
::
make_pair
(
std
::
get
<
0
>
(
drop_seed_offset
),
std
::
get
<
1
>
(
drop_seed_offset
)));
}
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size_
,
ck_tile
::
index_t
nhead_
,
ck_tile
::
index_t
seqlen_q_
,
ck_tile
::
index_t
hdim_v_
)
...
...
@@ -530,7 +870,7 @@ struct FmhaFwdKernel
return
TilePartitioner
::
GridSize
(
batch_size_
,
nhead_
,
seqlen_q_
,
hdim_v_
);
}
__host__
static
constexpr
auto
BlockSize
()
{
return
dim3
(
kBlockSize
);
}
CK_TILE_HOST
static
constexpr
auto
BlockSize
()
{
return
dim3
(
kBlockSize
);
}
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
View file @
4525c5d7
...
...
@@ -35,6 +35,7 @@ struct FmhaFwdSplitKVKernel
using
LSEDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
LSEDataType
>
;
using
SaccDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
SaccDataType
>
;
using
OaccDataType
=
remove_cvref_t
<
typename
FmhaPipeline
::
OaccDataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
FmhaPipeline
::
ODataType
>
;
using
VLayout
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
VLayout
>
;
...
...
@@ -46,8 +47,7 @@ struct FmhaFwdSplitKVKernel
static
constexpr
auto
BiasEnum
=
FmhaPipeline
::
BiasEnum
;
static
constexpr
bool
kDoFp8StaticQuant
=
FmhaPipeline
::
Problem
::
kDoFp8StaticQuant
;
static
constexpr
bool
kIsPagedKV
=
FmhaPipeline
::
Problem
::
kIsPagedKV
;
static_assert
(
!
kIsGroupMode
||
(
kIsGroupMode
&&
!
kIsPagedKV
),
"paged-kvcache only supported by batch mode kernels"
);
using
FmhaMask
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
FmhaMask
>
;
static
constexpr
bool
kHasMask
=
FmhaMask
::
IsMasking
;
...
...
@@ -172,13 +172,18 @@ struct FmhaFwdSplitKVKernel
float
scale_p
;
};
struct
PageBlockTableKargs
struct
Common
PageBlockTableKargs
{
const
int32_t
*
block_table_ptr
;
ck_tile
::
index_t
batch_stride_block_table
;
ck_tile
::
index_t
page_block_size
;
};
struct
GroupModePageBlockTableKargs
:
CommonPageBlockTableKargs
{
bool
is_gappy
=
false
;
};
struct
CacheBatchIdxKargs
{
const
int32_t
*
cache_batch_idx
;
...
...
@@ -193,13 +198,15 @@ struct FmhaFwdSplitKVKernel
EmptyKargs
<
0
>>>
,
std
::
conditional_t
<
kHasMask
,
MaskKargs
,
EmptyKargs
<
1
>>
,
std
::
conditional_t
<
kDoFp8StaticQuant
,
Fp8StaticQuantKargs
,
EmptyKargs
<
2
>>
,
std
::
conditional_t
<
kIsPagedKV
,
PageBlockTableKargs
,
CacheBatchIdxKargs
>
std
::
conditional_t
<
kIsPagedKV
,
Common
PageBlockTableKargs
,
CacheBatchIdxKargs
>
{
const
int32_t
*
seqlen_k_ptr
;
ck_tile
::
index_t
batch_stride_q
;
ck_tile
::
index_t
batch_stride_k
;
ck_tile
::
index_t
batch_stride_v
;
ck_tile
::
index_t
batch_stride_k
;
// when using paged-kvcache, this will be stride/size for
// single kcache page-block
ck_tile
::
index_t
batch_stride_v
;
// when using paged-kvcache, this will be stride/size for
// single vcache page-block
ck_tile
::
index_t
batch_stride_lse_acc
;
ck_tile
::
index_t
batch_stride_o_acc
;
};
...
...
@@ -212,14 +219,17 @@ struct FmhaFwdSplitKVKernel
AlibiKargs
,
EmptyKargs
<
0
>>>
,
std
::
conditional_t
<
kHasMask
,
MaskKargs
,
EmptyKargs
<
1
>>
,
std
::
conditional_t
<
kDoFp8StaticQuant
,
Fp8StaticQuantKargs
,
EmptyKargs
<
2
>>
std
::
conditional_t
<
kDoFp8StaticQuant
,
Fp8StaticQuantKargs
,
EmptyKargs
<
2
>>
,
std
::
conditional_t
<
kIsPagedKV
,
GroupModePageBlockTableKargs
,
EmptyKargs
<
3
>>
{
const
int32_t
*
seqstart_q_ptr
;
const
int32_t
*
seqstart_k_ptr
;
const
int32_t
*
seqlen_k_ptr
;
ck_tile
::
index_t
batch_stride_k
;
// only used for paged-kvcache
ck_tile
::
index_t
batch_stride_v
;
// only used for paged-kvcache
ck_tile
::
index_t
batch_stride_k
;
// only used for paged-kvcache, this will be stride/size
// for single kcache page-block
ck_tile
::
index_t
batch_stride_v
;
// only used for paged-kvcache, this will be stride/size
// for single vcache page-block
};
using
Kargs
=
std
::
conditional_t
<
kIsGroupMode
,
GroupModeKargs
,
BatchModeKargs
>
;
...
...
@@ -230,8 +240,10 @@ struct FmhaFwdSplitKVKernel
const
void
*
k_ptr
,
const
void
*
v_ptr
,
const
void
*
bias_ptr
,
void
*
lse_acc_ptr
,
void
*
o_acc_ptr
,
void
*
lse_acc_ptr
,
/* workspace for lse accumulation when num_splits > 1, otherwise
final lse */
void
*
o_acc_ptr
,
/* workspace for o accumulation when num_splits > 1, otherwise final
o */
ck_tile
::
index_t
batch
,
ck_tile
::
index_t
seqlen_q
,
ck_tile
::
index_t
seqlen_k
,
// only used if 'seqlen_k_ptr' is not specified
...
...
@@ -352,8 +364,10 @@ struct FmhaFwdSplitKVKernel
const
void
*
k_ptr
,
const
void
*
v_ptr
,
const
void
*
bias_ptr
,
void
*
lse_acc_ptr
,
void
*
o_acc_ptr
,
void
*
lse_acc_ptr
,
/* workspace for lse accumulation when num_splits > 1, otherwise
final lse */
void
*
o_acc_ptr
,
/* workspace for o accumulation when num_splits > 1, otherwise final
o */
ck_tile
::
index_t
batch
,
const
void
*
seqstart_q_ptr
,
const
void
*
seqstart_k_ptr
,
...
...
@@ -363,6 +377,10 @@ struct FmhaFwdSplitKVKernel
ck_tile
::
index_t
num_head_q
,
ck_tile
::
index_t
nhead_ratio_qk
,
ck_tile
::
index_t
num_splits
,
const
void
*
block_table_ptr
,
ck_tile
::
index_t
batch_stride_block_table
,
ck_tile
::
index_t
page_block_size
,
bool
is_gappy
,
float
scale_s
,
float
scale_p
,
ck_tile
::
index_t
stride_q
,
...
...
@@ -416,6 +434,7 @@ struct FmhaFwdSplitKVKernel
{},
// placeholder for bias
{},
// placeholder for mask
{},
// placeholder for fp8_static_quant args
{},
// placeholder for paged-block table
reinterpret_cast
<
const
int32_t
*>
(
seqstart_q_ptr
),
reinterpret_cast
<
const
int32_t
*>
(
seqstart_k_ptr
),
reinterpret_cast
<
const
int32_t
*>
(
seqlen_k_ptr
),
...
...
@@ -443,6 +462,13 @@ struct FmhaFwdSplitKVKernel
{
kargs
.
scale_p
=
scale_p
;
}
if
constexpr
(
kIsPagedKV
)
{
kargs
.
block_table_ptr
=
reinterpret_cast
<
const
int32_t
*>
(
block_table_ptr
);
kargs
.
batch_stride_block_table
=
batch_stride_block_table
;
kargs
.
page_block_size
=
page_block_size
;
kargs
.
is_gappy
=
is_gappy
;
}
return
kargs
;
}
...
...
@@ -476,11 +502,13 @@ struct FmhaFwdSplitKVKernel
const
index_t
i_n1
=
__builtin_amdgcn_readfirstlane
(
i_tile_n
*
FmhaPipeline
::
kN1
);
long_index_t
batch_offset_q
=
0
;
long_index_t
batch_offset_k
=
0
;
long_index_t
batch_offset_v
=
0
;
long_index_t
batch_offset_k
=
0
;
// unused for paged-kvcache
long_index_t
batch_offset_v
=
0
;
// unused for paged-kvcache
long_index_t
batch_offset_bias
=
0
;
long_index_t
batch_offset_lse_acc
=
0
;
long_index_t
batch_offset_o_acc
=
0
;
index_t
kv_l2p_offset
=
0
;
// logical-to-physical offset of seqlen_k coordinate. only used for paged-kvcache
if
constexpr
(
kIsGroupMode
)
{
...
...
@@ -490,7 +518,6 @@ struct FmhaFwdSplitKVKernel
batch_offset_q
=
query_start
*
kargs
.
stride_q
;
batch_offset_k
=
key_start
*
kargs
.
stride_k
;
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
batch_offset_v
=
key_start
*
kargs
.
stride_v
;
...
...
@@ -525,6 +552,15 @@ struct FmhaFwdSplitKVKernel
{
kargs
.
seqlen_k
=
kargs
.
seqstart_k_ptr
[
i_batch
+
1
]
-
kargs
.
seqstart_k_ptr
[
i_batch
];
}
if
constexpr
(
kIsPagedKV
)
{
if
(
kargs
.
is_gappy
)
{
// seqstart_k_ptr has different meaning in this case
kv_l2p_offset
=
kargs
.
seqstart_k_ptr
[
i_batch
];
}
}
}
else
{
...
...
@@ -570,7 +606,7 @@ struct FmhaFwdSplitKVKernel
static_cast
<
long_index_t
>
(
i_nhead
/
kargs
.
nhead_ratio_qk
)
*
kargs
.
nhead_stride_v
+
batch_offset_v
;
O
acc
DataType
*
o_acc_ptr
=
reinterpret_cast
<
O
acc
DataType
*>
(
kargs
.
o_acc_ptr
)
+
ODataType
*
o_acc_ptr
=
reinterpret_cast
<
ODataType
*>
(
kargs
.
o_acc_ptr
)
+
static_cast
<
long_index_t
>
(
i_nhead
)
*
kargs
.
nhead_stride_o_acc
+
batch_offset_o_acc
+
i_split
*
kargs
.
split_stride_o_acc
;
...
...
@@ -677,7 +713,7 @@ struct FmhaFwdSplitKVKernel
reinterpret_cast
<
const
int32_t
*>
(
kargs
.
block_table_ptr
)
+
i_batch_
*
kargs
.
batch_stride_block_table
;
const
index_t
num_blocks
=
integer_divide_ceil
(
kargs
.
seqlen_k
,
kargs
.
page_block_size
);
integer_divide_ceil
(
kv_l2p_offset
+
kargs
.
seqlen_k
,
kargs
.
page_block_size
);
const
long_index_t
fixed_offset
=
static_cast
<
long_index_t
>
(
i_nhead_
/
kargs
.
nhead_ratio_qk
)
*
...
...
@@ -685,14 +721,15 @@ struct FmhaFwdSplitKVKernel
return
make_page_block_navigator
<
const
KDataType
,
0
>
(
kargs
.
k_ptr
,
kargs
.
batch_stride_k
,
kargs
.
batch_stride_k
,
// kcache page-block stride/size
fixed_offset
,
block_indices
,
num_blocks
,
kargs
.
page_block_size
,
k_dram
,
make_k_dram
(
nullptr
,
kargs
.
seqlen_k
-
(
num_blocks
-
1
)
*
kargs
.
page_block_size
));
(
kv_l2p_offset
+
kargs
.
seqlen_k
)
-
(
num_blocks
-
1
)
*
kargs
.
page_block_size
));
}
else
{
...
...
@@ -707,7 +744,7 @@ struct FmhaFwdSplitKVKernel
reinterpret_cast
<
const
int32_t
*>
(
kargs
.
block_table_ptr
)
+
i_batch_
*
kargs
.
batch_stride_block_table
;
const
index_t
num_blocks
=
integer_divide_ceil
(
kargs
.
seqlen_k
,
kargs
.
page_block_size
);
integer_divide_ceil
(
kv_l2p_offset
+
kargs
.
seqlen_k
,
kargs
.
page_block_size
);
const
long_index_t
fixed_offset
=
static_cast
<
long_index_t
>
(
i_nhead_
/
kargs
.
nhead_ratio_qk
)
*
...
...
@@ -715,14 +752,15 @@ struct FmhaFwdSplitKVKernel
return
make_page_block_navigator
<
const
VDataType
,
1
>
(
kargs
.
v_ptr
,
kargs
.
batch_stride_v
,
kargs
.
batch_stride_v
,
// vcache page-block stride/size
fixed_offset
,
block_indices
,
num_blocks
,
kargs
.
page_block_size
,
v_dram
,
make_v_dram
(
nullptr
,
kargs
.
seqlen_k
-
(
num_blocks
-
1
)
*
kargs
.
page_block_size
));
(
kv_l2p_offset
+
kargs
.
seqlen_k
)
-
(
num_blocks
-
1
)
*
kargs
.
page_block_size
));
}
else
{
...
...
@@ -870,6 +908,7 @@ struct FmhaFwdSplitKVKernel
mask
,
position_encoding
,
kargs
.
scale_s
,
kv_l2p_offset
,
smem_ptr
);
}
else
...
...
@@ -886,6 +925,7 @@ struct FmhaFwdSplitKVKernel
mask
,
position_encoding
,
kargs
.
scale_s
,
kv_l2p_offset
,
smem_ptr
);
}
}();
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp
View file @
4525c5d7
...
...
@@ -18,7 +18,7 @@ struct FmhaFwdSplitKVTilePartitioner
static
constexpr
ck_tile
::
index_t
kN1
=
BlockFmhaShape
::
kN1
;
static
constexpr
ck_tile
::
index_t
kK1
=
BlockFmhaShape
::
kK1
;
__host__
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
,
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
,
ck_tile
::
index_t
nhead
,
ck_tile
::
index_t
max_seqlen_q
,
ck_tile
::
index_t
hdim_v
,
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
View file @
4525c5d7
...
...
@@ -25,6 +25,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
using
PDataType
=
remove_cvref_t
<
typename
Problem
::
PDataType
>
;
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
using
FmhaMask
=
remove_cvref_t
<
typename
Problem
::
FmhaMask
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
...
...
@@ -48,7 +49,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
static
constexpr
bool
kPadHeadDimQ
=
Problem
::
kPadHeadDimQ
;
static
constexpr
bool
kPadHeadDimV
=
Problem
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
bool
kStoreLSE
=
true
;
// always s
tore
LSE
(acc)
static
constexpr
bool
kStoreLSE
=
Problem
::
kS
toreLSE
;
static
constexpr
bool
kIsPagedKV
=
Problem
::
kIsPagedKV
;
static
constexpr
bool
kHasUnevenSplits
=
Problem
::
kHasUnevenSplits
;
...
...
@@ -142,6 +143,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
FmhaMask
mask
,
PositionEncoding
position_encoding
,
float
scale_s
,
index_t
kv_l2p_offset
,
// logical-to-physical offset of seqlen_k coordinate
void
*
smem_ptr
)
const
{
static_assert
(
...
...
@@ -211,15 +213,15 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
clear_tile
(
l
);
const
auto
q_origin
=
q_dram_window
.
get_window_origin
();
const
auto
[
seqlen_k_start
,
seqlen_k_end
]
=
mask
.
GetTileRangeAlongX
(
const
auto
[
logical_
seqlen_k_start
,
logical_
seqlen_k_end
]
=
mask
.
GetTileRangeAlongX
(
q_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{},
num_splits
,
i_split
);
// check early exit if no work to do
if
constexpr
(
FmhaMask
::
IsMasking
||
kPadSeqLenK
||
kHasUnevenSplits
)
{
const
index_t
o
ri
gi
n
al_num_total_loop
=
integer_divide_ceil
(
seqlen_k_end
-
seqlen_k_start
,
kN0
);
if
(
o
ri
gi
n
al_num_total_loop
<=
0
)
const
index_t
l
ogi
c
al_num_total_loop
=
integer_divide_ceil
(
logical_
seqlen_k_end
-
logical_
seqlen_k_start
,
kN0
);
if
(
l
ogi
c
al_num_total_loop
<=
0
)
{
if
constexpr
(
kStoreLSE
)
{
...
...
@@ -238,33 +240,41 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
}
}
// make sure the first tile is completely located in page-block
const
index_t
adjusted_seqlen_k_start
=
[
&
,
seqlen_k_start_
=
seqlen_k_start
]
{
const
index_t
physical_seqlen_k_start
=
logical_seqlen_k_start
+
kv_l2p_offset
;
const
index_t
physical_seqlen_k_end
=
logical_seqlen_k_end
+
kv_l2p_offset
;
// make sure the first tile is completely located in page-block (page-block size should be
// divisible by kN0)
// relationship between each *_start variables: aligned_physical_seqlen_k_start <=
// physical_seqlen_k_start, logical_seqlen_k_start <= physical_seqlen_k_start
const
index_t
aligned_physical_seqlen_k_start
=
[
&
,
physical_seqlen_k_start_
=
physical_seqlen_k_start
]
{
if
constexpr
(
kIsPagedKV
)
{
return
kN0
*
integer_divide_floor
(
seqlen_k_start_
,
kN0
);
return
kN0
*
integer_divide_floor
(
physical_
seqlen_k_start_
,
kN0
);
}
else
{
return
seqlen_k_start_
;
return
physical_
seqlen_k_start_
;
}
}();
const
index_t
num_total_loop
=
integer_divide_ceil
(
seqlen_k_end
-
a
djusted
_seqlen_k_start
,
kN0
);
integer_divide_ceil
(
physical_
seqlen_k_end
-
a
ligned_physical
_seqlen_k_start
,
kN0
);
auto
[
i_page_block_k
,
k_dram_block_window
]
=
k_page_block_navigator
.
make_tile_window
(
k_dram_block_window_lengths
,
{
a
djusted
_seqlen_k_start
,
0
});
k_dram_block_window_lengths
,
{
a
ligned_physical
_seqlen_k_start
,
0
});
const
auto
bias_origin
=
bias_dram_block_window_tmp
.
get_window_origin
();
auto
bias_dram_window
=
make_tile_window
(
bias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
bias_dram_block_window_tmp
.
get_window_lengths
(),
{
bias_origin
.
at
(
number
<
0
>
{}),
adjusted_seqlen_k_start
},
// M/N
{
bias_origin
.
at
(
number
<
0
>
{}),
logical_seqlen_k_start
-
(
physical_seqlen_k_start
-
aligned_physical_seqlen_k_start
)},
// M/N
Policy
::
template
MakeBiasDramTileDistribution
<
decltype
(
gemm_0
)>());
auto
[
i_page_block_v
,
v_dram_window
]
=
v_page_block_navigator
.
make_tile_window
(
v_dram_block_window_lengths
,
{
0
,
a
djusted
_seqlen_k_start
},
// TODO: hdim split?
{
0
,
a
ligned_physical
_seqlen_k_start
},
// TODO: hdim split?
Policy
::
template
MakeVDramTileDistribution
<
Problem
>());
auto
q_tile
=
tile_elementwise_in
(
q_element_func
,
q
);
...
...
@@ -378,7 +388,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
s_acc
(
i_j_idx
)
*=
scale_s
;
position_encoding
.
update
(
s_acc
(
i_j_idx
),
row
,
col
);
// position_encoding accept only logical coordinates, do conversion here
position_encoding
.
update
(
s_acc
(
i_j_idx
),
row
,
col
-
kv_l2p_offset
);
});
});
}
...
...
@@ -396,19 +407,20 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
{
const
auto
k_origin
=
k_page_block_navigator
.
to_global_window_origin
(
i_page_block_k
,
k_dram_block_window
.
get_window_origin
());
set_tile_if
(
s_acc
,
set_tile_if
(
s_acc
,
-
numeric
<
SMPLComputeDataType
>::
infinity
(),
[
&
,
seqlen_k_start_
=
seqlen_k_start
,
seqlen_k_end_
=
seqlen_k_end
](
auto
tile_idx
)
{
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
[
&
,
physical_seqlen_k_start_
=
physical_seqlen_k_start
,
physical_seqlen_k_end_
=
physical_seqlen_k_end
](
auto
tile_idx
)
{
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
if
constexpr
(
kIsPagedKV
)
{
return
col
<
seqlen_k_start_
||
seqlen_k_end_
<=
col
;
return
col
<
physical_
seqlen_k_start_
||
physical_
seqlen_k_end_
<=
col
;
}
else
{
return
seqlen_k_end_
<=
col
;
return
physical_
seqlen_k_end_
<=
col
;
}
});
}
...
...
@@ -417,8 +429,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
{
const
auto
k_origin
=
k_page_block_navigator
.
to_global_window_origin
(
i_page_block_k
,
k_dram_block_window
.
get_window_origin
());
// mask accept only logical coordinates, do conversion here
bool
need_perpixel_check
=
mask
.
IsEdgeTile
(
q_origin
.
at
(
number
<
0
>
{}),
k_origin
.
at
(
number
<
0
>
{}),
k_origin
.
at
(
number
<
0
>
{})
-
kv_l2p_offset
,
number
<
kM0
>
{},
number
<
kN0
>
{});
if
(
need_perpixel_check
)
...
...
@@ -427,7 +440,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
s_acc
,
-
numeric
<
SMPLComputeDataType
>::
infinity
(),
[
&
](
auto
tile_idx
)
{
const
auto
row
=
q_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
return
mask
.
IsOutOfBound
(
row
,
col
);
return
mask
.
IsOutOfBound
(
row
,
col
-
kv_l2p_offset
);
});
}
}
...
...
@@ -658,6 +671,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
FmhaMask
mask
,
PositionEncoding
position_encoding
,
float
scale_s
,
index_t
kv_l2p_offset
,
// logical-to-physical offset of seqlen_k coordinate
void
*
smem_ptr
)
const
{
return
operator
()(
q_dram_block_window_tmp
,
...
...
@@ -680,6 +694,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
mask
,
position_encoding
,
scale_s
,
kv_l2p_offset
,
smem_ptr
);
}
};
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
View file @
4525c5d7
...
...
@@ -331,7 +331,8 @@ struct BlockFmhaPipelineQRKSVSAsync
Policy
::
template
MakeVDramTileDistribution
<
Problem
>());
// prefetch K tile
async_load_tile_raw
(
k_lds_store
(
LdsSeq
.
at
(
number
<
0
>
{})),
k_dram_window
,
k_oob_ck
,
k_pre_np
);
async_load_tile_raw
(
k_lds_store
(
LdsSeq
.
at
(
number
<
0
>
{})),
k_dram_window
,
number
<-
1
>
{},
k_oob_ck
,
k_pre_np
);
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
__builtin_amdgcn_sched_barrier
(
0
);
...
...
@@ -355,6 +356,7 @@ struct BlockFmhaPipelineQRKSVSAsync
static_for
<
0
,
k0_loops
-
1
,
1
>
{}([
&
](
auto
i_k0
)
{
async_load_tile_raw
(
k_lds_store
(
number
<
LdsSeq
.
at
(
number
<
i_k0
+
1
>
{})
>
{}),
k_dram_window
,
number
<-
1
>
{},
k_oob_ck
,
k_pre_np
);
if
constexpr
(
i_k0
<
k0_loops
-
1
)
...
...
@@ -386,7 +388,7 @@ struct BlockFmhaPipelineQRKSVSAsync
__builtin_amdgcn_s_barrier
();
const
auto
bias_tile
=
load_tile
(
bias_dram_window
);
// load bias tile
auto
v_buf
=
load_tile
(
v_dram_window
,
bool_constant
<
false
>
{});
auto
v_buf
=
load_tile
(
v_dram_window
,
number
<-
1
>
{},
bool_constant
<
false
>
{});
__builtin_amdgcn_sched_barrier
(
0
);
{
// tail
gemm_0
(
s_acc
,
...
...
@@ -514,7 +516,8 @@ struct BlockFmhaPipelineQRKSVSAsync
move_tile_window
(
v_dram_window
,
{
0
,
kK1
});
// will have scratch if move this right after load_tile(v_dram)...
v_buf
=
load_tile
(
v_dram_window
,
bool_constant
<
false
>
{});
// load next v_buf
v_buf
=
load_tile
(
v_dram_window
,
number
<-
1
>
{},
bool_constant
<
false
>
{});
// load next v_buf
}
__builtin_amdgcn_sched_barrier
(
0
);
...
...
@@ -618,7 +621,8 @@ struct BlockFmhaPipelineQRKSVSAsync
static_for
<
0
,
k1_loops
-
1
,
1
>
{}([
&
](
auto
i_k1
)
{
if
constexpr
(
i_k1
!=
0
&&
i_k1
<
k1_loops
-
1
)
{
v_buf
=
load_tile
(
v_dram_window
,
bool_constant
<
false
>
{});
// load next v_buf
v_buf
=
load_tile
(
v_dram_window
,
number
<-
1
>
{},
bool_constant
<
false
>
{});
// load next v_buf
}
block_sync_lds
();
gemm_1
(
o_acc
,
...
...
@@ -665,8 +669,11 @@ struct BlockFmhaPipelineQRKSVSAsync
if
constexpr
(
k1_loops
>=
2
&&
LdsSeq
.
at
(
number
<
0
>
{})
==
LdsSeq
.
at
(
number
<
k0_loops
+
k1_loops
-
2
>
{}))
__builtin_amdgcn_s_barrier
();
async_load_tile_raw
(
k_lds_store
(
LdsSeq
.
at
(
number
<
0
>
{})),
k_dram_window
,
k_oob_ck
,
k_pre_np
);
async_load_tile_raw
(
k_lds_store
(
LdsSeq
.
at
(
number
<
0
>
{})),
k_dram_window
,
number
<-
1
>
{},
k_oob_ck
,
k_pre_np
);
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
}
// tail
...
...
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp
View file @
4525c5d7
...
...
@@ -39,7 +39,7 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool
kPadHeadDimV_
/* paddding for hdim_v */
,
BlockAttentionBiasEnum
BiasEnum_
,
bool
kHasBiasGrad_
,
bool
kStoreLSE_
,
bool
kStoreLSE_
,
/* set to true if either num_splits > 1 or fwd training is running */
bool
kDoFp8StaticQuant_
,
bool
kIsPagedKV_
,
bool
kHasUnevenSplits_
,
...
...
Prev
1
…
4
5
6
7
8
9
10
11
12
…
16
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment