Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
a4501f13
Commit
a4501f13
authored
Jan 21, 2025
by
Adam Osewski
Browse files
Merge remote-tracking branch 'origin/develop' into aosewski/ck_tile_gemm_policy
parents
c6dcf20d
e7dce4d2
Changes
368
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1227 additions
and
457 deletions
+1227
-457
include/ck_tile/core/numeric/bfloat16.hpp
include/ck_tile/core/numeric/bfloat16.hpp
+11
-1
include/ck_tile/core/tensor/static_distributed_tensor.hpp
include/ck_tile/core/tensor/static_distributed_tensor.hpp
+1
-0
include/ck_tile/host.hpp
include/ck_tile/host.hpp
+1
-1
include/ck_tile/host/arg_parser.hpp
include/ck_tile/host/arg_parser.hpp
+44
-2
include/ck_tile/host/check_err.hpp
include/ck_tile/host/check_err.hpp
+125
-1
include/ck_tile/host/reference/reference_fused_moe.hpp
include/ck_tile/host/reference/reference_fused_moe.hpp
+25
-16
include/ck_tile/host/reference/reference_gemm.hpp
include/ck_tile/host/reference/reference_gemm.hpp
+11
-151
include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp
include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp
+30
-4
include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp
include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp
+1
-1
include/ck_tile/ops/common.hpp
include/ck_tile/ops/common.hpp
+1
-1
include/ck_tile/ops/elementwise.hpp
include/ck_tile/ops/elementwise.hpp
+1
-1
include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp
.../ck_tile/ops/elementwise/unary_element_wise_operation.hpp
+75
-0
include/ck_tile/ops/epilogue.hpp
include/ck_tile/ops/epilogue.hpp
+1
-1
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
+27
-4
include/ck_tile/ops/epilogue/default_2d_epilogue.hpp
include/ck_tile/ops/epilogue/default_2d_epilogue.hpp
+22
-4
include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp
include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp
+55
-31
include/ck_tile/ops/flatmm.hpp
include/ck_tile/ops/flatmm.hpp
+2
-1
include/ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp
...ile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp
+282
-234
include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp
.../ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp
+2
-3
include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp
.../flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp
+510
-0
No files found.
include/ck_tile/core/numeric/bfloat16.hpp
View file @
a4501f13
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
...
...
@@ -376,6 +376,16 @@ struct numeric<bfloat16_t>
}
};
template
<
typename
T
>
struct
numeric_traits
;
template
<
>
struct
numeric_traits
<
bfloat16_t
>
{
static
constexpr
int
exp
=
8
;
static
constexpr
int
mant
=
7
;
};
#if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_ARITHMETIC_USING_FLOAT
(
CK_TILE_HOST_DEVICE
,
bfloat16_t
)
#endif
...
...
include/ck_tile/core/tensor/static_distributed_tensor.hpp
View file @
a4501f13
...
...
@@ -29,6 +29,7 @@ struct static_distributed_tensor
remove_cvref_t
<
decltype
(
StaticTileDistribution
{}.
get_ys_to_d_descriptor
())
>
;
static
constexpr
index_t
kThreadElementSpaceSize
=
ThreadTensorDesc
{}.
get_element_space_size
();
static_assert
(
0
<
kThreadElementSpaceSize
,
"Make sure tile distribution is valid"
);
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_num_of_dimension
()
{
...
...
include/ck_tile/host.hpp
View file @
a4501f13
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
include/ck_tile/host/arg_parser.hpp
View file @
a4501f13
...
...
@@ -15,11 +15,14 @@
namespace
ck_tile
{
/*
* a host side utility, arg parser for
* -[key0]=[value0] -[key1]=[value1] ...
* a host side utility, arg parser for, either
* -[key0] = [value0, value1, value2]
* or
* -[key0]=[value0] -[key1]=[value1] ...
*/
class
ArgParser
{
public:
class
Arg
{
...
...
@@ -187,6 +190,45 @@ class ArgParser
return
value
;
}
std
::
vector
<
std
::
string
>
get_string_vec
(
const
std
::
string
&
name
,
const
std
::
string
&
delimiter
=
","
)
const
{
if
(
get_str
(
name
).
empty
())
{
return
{};
}
std
::
string
s
=
get_str
(
name
);
std
::
vector
<
std
::
string
>
tokens
;
size_t
pos
=
0
;
std
::
string
token
;
while
((
pos
=
s
.
find
(
delimiter
))
!=
std
::
string
::
npos
)
{
token
=
s
.
substr
(
0
,
pos
);
tokens
.
push_back
(
token
);
s
.
erase
(
0
,
pos
+
delimiter
.
length
());
}
tokens
.
push_back
(
s
);
return
tokens
;
}
std
::
vector
<
int
>
get_int_vec
(
const
std
::
string
&
name
,
const
std
::
string
&
delimiter
=
","
)
const
{
if
(
get_str
(
name
).
empty
())
{
return
{};
}
const
std
::
vector
<
std
::
string
>
args
=
get_string_vec
(
name
,
delimiter
);
std
::
vector
<
int
>
tokens
;
tokens
.
reserve
(
static_cast
<
int
>
(
args
.
size
()));
for
(
const
std
::
string
&
token
:
args
)
{
int
value
=
atoi
(
token
.
c_str
());
tokens
.
push_back
(
value
);
}
return
tokens
;
}
private:
std
::
unordered_map
<
std
::
string
,
Arg
>
input_map
;
std
::
vector
<
std
::
string
>
keys
;
...
...
include/ck_tile/host/check_err.hpp
View file @
a4501f13
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -18,6 +18,130 @@
namespace
ck_tile
{
template
<
typename
ComputeDataType
,
typename
OutDataType
,
typename
AccDataType
=
ComputeDataType
>
double
get_relative_threshold
(
const
int
number_of_accumulations
=
1
)
{
using
F8
=
ck_tile
::
fp8_t
;
using
F16
=
ck_tile
::
half_t
;
using
BF16
=
ck_tile
::
bf16_t
;
using
F32
=
float
;
using
I8
=
int8_t
;
using
I32
=
int32_t
;
static_assert
(
std
::
is_same_v
<
ComputeDataType
,
F8
>
||
std
::
is_same_v
<
ComputeDataType
,
F16
>
||
std
::
is_same_v
<
ComputeDataType
,
BF16
>
||
std
::
is_same_v
<
ComputeDataType
,
F32
>
||
std
::
is_same_v
<
ComputeDataType
,
I8
>
||
std
::
is_same_v
<
ComputeDataType
,
I32
>
||
std
::
is_same_v
<
ComputeDataType
,
int
>
,
"Warning: Unhandled ComputeDataType for setting up the relative threshold!"
);
double
compute_error
=
0
;
if
constexpr
(
std
::
is_same_v
<
ComputeDataType
,
I8
>
||
std
::
is_same_v
<
ComputeDataType
,
I32
>
||
std
::
is_same_v
<
ComputeDataType
,
int
>
)
{
return
0
;
}
else
{
compute_error
=
std
::
pow
(
2
,
-
numeric_traits
<
ComputeDataType
>::
mant
)
*
0.5
;
}
static_assert
(
std
::
is_same_v
<
OutDataType
,
F8
>
||
std
::
is_same_v
<
OutDataType
,
F16
>
||
std
::
is_same_v
<
OutDataType
,
BF16
>
||
std
::
is_same_v
<
OutDataType
,
F32
>
||
std
::
is_same_v
<
OutDataType
,
I8
>
||
std
::
is_same_v
<
OutDataType
,
I32
>
||
std
::
is_same_v
<
OutDataType
,
int
>
,
"Warning: Unhandled OutDataType for setting up the relative threshold!"
);
double
output_error
=
0
;
if
constexpr
(
std
::
is_same_v
<
OutDataType
,
I8
>
||
std
::
is_same_v
<
OutDataType
,
I32
>
||
std
::
is_same_v
<
OutDataType
,
int
>
)
{
return
0
;
}
else
{
output_error
=
std
::
pow
(
2
,
-
numeric_traits
<
OutDataType
>::
mant
)
*
0.5
;
}
double
midway_error
=
std
::
max
(
compute_error
,
output_error
);
static_assert
(
std
::
is_same_v
<
AccDataType
,
F8
>
||
std
::
is_same_v
<
AccDataType
,
F16
>
||
std
::
is_same_v
<
AccDataType
,
BF16
>
||
std
::
is_same_v
<
AccDataType
,
F32
>
||
std
::
is_same_v
<
AccDataType
,
I8
>
||
std
::
is_same_v
<
AccDataType
,
I32
>
||
std
::
is_same_v
<
AccDataType
,
int
>
,
"Warning: Unhandled AccDataType for setting up the relative threshold!"
);
double
acc_error
=
0
;
if
constexpr
(
std
::
is_same_v
<
AccDataType
,
I8
>
||
std
::
is_same_v
<
AccDataType
,
I32
>
||
std
::
is_same_v
<
AccDataType
,
int
>
)
{
return
0
;
}
else
{
acc_error
=
std
::
pow
(
2
,
-
numeric_traits
<
AccDataType
>::
mant
)
*
0.5
*
number_of_accumulations
;
}
return
std
::
max
(
acc_error
,
midway_error
);
}
template
<
typename
ComputeDataType
,
typename
OutDataType
,
typename
AccDataType
=
ComputeDataType
>
double
get_absolute_threshold
(
const
double
max_possible_num
,
const
int
number_of_accumulations
=
1
)
{
using
F8
=
ck_tile
::
fp8_t
;
using
F16
=
ck_tile
::
half_t
;
using
BF16
=
ck_tile
::
bf16_t
;
using
F32
=
float
;
using
I8
=
int8_t
;
using
I32
=
int32_t
;
static_assert
(
std
::
is_same_v
<
ComputeDataType
,
F8
>
||
std
::
is_same_v
<
ComputeDataType
,
F16
>
||
std
::
is_same_v
<
ComputeDataType
,
BF16
>
||
std
::
is_same_v
<
ComputeDataType
,
F32
>
||
std
::
is_same_v
<
ComputeDataType
,
I8
>
||
std
::
is_same_v
<
ComputeDataType
,
I32
>
||
std
::
is_same_v
<
ComputeDataType
,
int
>
,
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!"
);
auto
expo
=
std
::
log2
(
std
::
abs
(
max_possible_num
));
double
compute_error
=
0
;
if
constexpr
(
std
::
is_same_v
<
ComputeDataType
,
I8
>
||
std
::
is_same_v
<
ComputeDataType
,
I32
>
||
std
::
is_same_v
<
ComputeDataType
,
int
>
)
{
return
0
;
}
else
{
compute_error
=
std
::
pow
(
2
,
expo
-
numeric_traits
<
ComputeDataType
>::
mant
)
*
0.5
;
}
static_assert
(
std
::
is_same_v
<
OutDataType
,
F8
>
||
std
::
is_same_v
<
OutDataType
,
F16
>
||
std
::
is_same_v
<
OutDataType
,
BF16
>
||
std
::
is_same_v
<
OutDataType
,
F32
>
||
std
::
is_same_v
<
OutDataType
,
I8
>
||
std
::
is_same_v
<
OutDataType
,
I32
>
||
std
::
is_same_v
<
OutDataType
,
int
>
,
"Warning: Unhandled OutDataType for setting up the absolute threshold!"
);
double
output_error
=
0
;
if
constexpr
(
std
::
is_same_v
<
OutDataType
,
I8
>
||
std
::
is_same_v
<
OutDataType
,
I32
>
||
std
::
is_same_v
<
OutDataType
,
int
>
)
{
return
0
;
}
else
{
output_error
=
std
::
pow
(
2
,
expo
-
numeric_traits
<
OutDataType
>::
mant
)
*
0.5
;
}
double
midway_error
=
std
::
max
(
compute_error
,
output_error
);
static_assert
(
std
::
is_same_v
<
AccDataType
,
F8
>
||
std
::
is_same_v
<
AccDataType
,
F16
>
||
std
::
is_same_v
<
AccDataType
,
BF16
>
||
std
::
is_same_v
<
AccDataType
,
F32
>
||
std
::
is_same_v
<
AccDataType
,
I8
>
||
std
::
is_same_v
<
AccDataType
,
I32
>
||
std
::
is_same_v
<
AccDataType
,
int
>
,
"Warning: Unhandled AccDataType for setting up the absolute threshold!"
);
double
acc_error
=
0
;
if
constexpr
(
std
::
is_same_v
<
AccDataType
,
I8
>
||
std
::
is_same_v
<
AccDataType
,
I32
>
||
std
::
is_same_v
<
AccDataType
,
int
>
)
{
return
0
;
}
else
{
acc_error
=
std
::
pow
(
2
,
expo
-
numeric_traits
<
AccDataType
>::
mant
)
*
0.5
*
number_of_accumulations
;
}
return
std
::
max
(
acc_error
,
midway_error
);
}
template
<
typename
T
>
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
std
::
vector
<
T
>&
v
)
{
...
...
include/ck_tile/host/reference/reference_fused_moe.hpp
View file @
a4501f13
...
...
@@ -73,7 +73,7 @@ void reference_fused_moe(
ck_tile
::
index_t
tokens
,
ck_tile
::
index_t
experts
,
ck_tile
::
index_t
hidden_size
,
ck_tile
::
index_t
intermediate_size
,
// this size is for gate/up
ck_tile
::
index_t
intermediate_size
,
// this size is for gate/up
/down
ck_tile
::
index_t
topk
,
ck_tile
::
index_t
gate_only
)
{
...
...
@@ -82,19 +82,8 @@ void reference_fused_moe(
assert
(
sorted_expert_ids_host
.
get_num_of_dimension
()
==
1
);
assert
(
num_sorted_tiles_host
.
get_element_size
()
==
1
);
ck_tile
::
index_t
num_sorted_tiles
=
num_sorted_tiles_host
.
mData
[
0
]
/
block_m
;
ck_tile
::
index_t
intermediate_size_0
=
intermediate_size
;
ck_tile
::
index_t
intermediate_size_1
=
intermediate_size
/
(
gate_only
?
1
:
2
);
// TODO: better remove this in the future, or modify the token_id value
auto
get_topk_id
=
[
&
](
ck_tile
::
index_t
token_id_
,
ck_tile
::
index_t
expert_id_
)
{
for
(
ck_tile
::
index_t
i_
=
0
;
i_
<
topk
;
i_
++
)
{
if
(
token_ids_host
(
token_id_
,
i_
)
==
expert_id_
)
return
i_
;
}
throw
std
::
runtime_error
(
"not correct token/expert pair
\n
"
);
return
-
1
;
// TODO: not correct!!
};
ck_tile
::
index_t
intermediate_size_0
=
intermediate_size
*
(
gate_only
?
1
:
2
);
ck_tile
::
index_t
intermediate_size_1
=
intermediate_size
;
ck_tile
::
HostTensor
<
AccDataType
>
out_topk_tokens
({
tokens
,
topk
,
hidden_size
});
...
...
@@ -105,11 +94,31 @@ void reference_fused_moe(
if
(
i_tile
>=
num_sorted_tiles
)
return
;
ck_tile
::
index_t
i_expert
=
sorted_expert_ids_host
.
mData
[
i_tile
];
ck_tile
::
index_t
i_token
=
sorted_token_ids_host
.
mData
[
i_flatten
];
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
ck_tile
::
index_t
i_token
=
sorted_token_ids_host
.
mData
[
i_flatten
];
ck_tile
::
index_t
i_topk
=
i_token
>>
24
;
i_token
&=
0xffffff
;
if
(
i_token
>=
tokens
)
return
;
(
void
)
token_ids_host
;
#else
// TODO: better remove this in the future, or modify the token_id value
auto
get_topk_id
=
[
&
](
ck_tile
::
index_t
token_id_
,
ck_tile
::
index_t
expert_id_
)
{
for
(
ck_tile
::
index_t
i_
=
0
;
i_
<
topk
;
i_
++
)
{
if
(
token_ids_host
(
token_id_
,
i_
)
==
expert_id_
)
return
i_
;
}
throw
std
::
runtime_error
(
"not correct token/expert pair
\n
"
);
return
-
1
;
// TODO: not correct!!
};
ck_tile
::
index_t
i_token
=
sorted_token_ids_host
.
mData
[
i_flatten
];
if
(
i_token
>=
tokens
)
return
;
ck_tile
::
index_t
i_topk
=
get_topk_id
(
i_token
,
i_expert
);
// TODO: ugly
auto
weight
=
sorted_weight_host
.
mData
[
i_flatten
];
#endif
auto
weight
=
sorted_weight_host
.
mData
[
i_flatten
];
ck_tile
::
HostTensor
<
AccDataType
>
acc_0
({
1
,
intermediate_size_0
});
// first gemm
...
...
include/ck_tile/host/reference/reference_gemm.hpp
View file @
a4501f13
...
...
@@ -97,9 +97,9 @@ template <typename ADataType,
typename
LayoutA
,
typename
LayoutB
,
typename
LayoutC
>
void
reference_gemm_gpu
(
DeviceMem
&
a_device
,
DeviceMem
&
b_device
,
DeviceMem
&
c_device
,
void
reference_gemm_gpu
(
ADataType
*
a_ptr
,
BDataType
*
b_ptr
,
CDataType
*
c_ptr
,
index_t
M
,
index_t
N
,
index_t
K
,
...
...
@@ -107,79 +107,13 @@ void reference_gemm_gpu(DeviceMem& a_device,
index_t
stride_b
,
index_t
stride_c
)
{
ADataType
*
d_A
;
BDataType
*
d_B
;
CDataType
*
d_C
;
hipError_t
errA
=
hipMalloc
(
&
d_A
,
M
*
K
*
sizeof
(
ADataType
));
hipError_t
errB
=
hipMalloc
(
&
d_B
,
N
*
K
*
sizeof
(
BDataType
));
hipError_t
errC
=
hipMalloc
(
&
d_C
,
M
*
N
*
sizeof
(
CDataType
));
if
(
errA
!=
hipSuccess
)
{
std
::
cerr
<<
"Error allocating device memory for A: "
<<
hipGetErrorString
(
errA
)
<<
std
::
endl
;
return
;
// Early exit on error
}
if
(
errB
!=
hipSuccess
)
{
std
::
cerr
<<
"Error allocating device memory for B: "
<<
hipGetErrorString
(
errB
)
<<
std
::
endl
;
return
;
// Early exit on error
}
if
(
errC
!=
hipSuccess
)
{
std
::
cerr
<<
"Error allocating device memory for C: "
<<
hipGetErrorString
(
errC
)
<<
std
::
endl
;
return
;
// Early exit on error
}
errA
=
hipMemcpy
(
d_A
,
a_device
.
GetDeviceBuffer
(),
M
*
K
*
sizeof
(
ADataType
),
hipMemcpyHostToDevice
);
if
(
errA
!=
hipSuccess
)
{
std
::
cerr
<<
"Error copying A to device: "
<<
hipGetErrorString
(
errA
)
<<
std
::
endl
;
}
errB
=
hipMemcpy
(
d_B
,
b_device
.
GetDeviceBuffer
(),
N
*
K
*
sizeof
(
BDataType
),
hipMemcpyHostToDevice
);
if
(
errB
!=
hipSuccess
)
{
std
::
cerr
<<
"Error copying B to device: "
<<
hipGetErrorString
(
errB
)
<<
std
::
endl
;
}
int
totalElements
=
M
*
N
;
int
numThreadsPerBlock
=
256
;
// Common choice for threads per block
int
numBlocks
=
(
totalElements
+
numThreadsPerBlock
-
1
)
/
numThreadsPerBlock
;
naive_gemm_kernel
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
LayoutA
,
LayoutB
,
LayoutC
>
<<<
numBlocks
,
numThreadsPerBlock
>>>
(
d_A
,
d_B
,
d_C
,
M
,
N
,
K
,
stride_a
,
stride_b
,
stride_c
);
errC
=
hipMemcpy
(
c_device
.
GetDeviceBuffer
(),
d_C
,
M
*
N
*
sizeof
(
CDataType
),
hipMemcpyDeviceToHost
);
if
(
errC
!=
hipSuccess
)
{
std
::
cerr
<<
"Error copying C to device: "
<<
hipGetErrorString
(
errC
)
<<
std
::
endl
;
}
errA
=
hipFree
(
d_A
);
if
(
errA
!=
hipSuccess
)
{
std
::
cerr
<<
"Error free the A memory: "
<<
hipGetErrorString
(
errA
)
<<
std
::
endl
;
}
errB
=
hipFree
(
d_B
);
if
(
errB
!=
hipSuccess
)
{
std
::
cerr
<<
"Error free the B memory: "
<<
hipGetErrorString
(
errB
)
<<
std
::
endl
;
}
errC
=
hipFree
(
d_C
);
if
(
errC
!=
hipSuccess
)
{
std
::
cerr
<<
"Error free the C memory: "
<<
hipGetErrorString
(
errC
)
<<
std
::
endl
;
}
<<<
numBlocks
,
numThreadsPerBlock
>>>
(
a_ptr
,
b_ptr
,
c_ptr
,
M
,
N
,
K
,
stride_a
,
stride_b
,
stride_c
);
return
;
}
...
...
@@ -191,9 +125,9 @@ template <typename ADataType,
typename
LayoutA
,
typename
LayoutB
,
typename
LayoutC
>
void
reference_batched_gemm_gpu
(
DeviceMem
&
a_device
,
DeviceMem
&
b_device
,
DeviceMem
&
c_device
,
void
reference_batched_gemm_gpu
(
ADataType
*
a_ptr
,
BDataType
*
b_ptr
,
CDataType
*
c_ptr
,
index_t
M
,
index_t
N
,
index_t
K
,
...
...
@@ -205,94 +139,20 @@ void reference_batched_gemm_gpu(DeviceMem& a_device,
index_t
batch_stride_C
,
index_t
batch_count
)
{
ADataType
*
d_A
;
BDataType
*
d_B
;
CDataType
*
d_C
;
hipError_t
errA
=
hipMalloc
(
&
d_A
,
batch_count
*
M
*
K
*
sizeof
(
ADataType
));
hipError_t
errB
=
hipMalloc
(
&
d_B
,
batch_count
*
N
*
K
*
sizeof
(
BDataType
));
hipError_t
errC
=
hipMalloc
(
&
d_C
,
batch_count
*
M
*
N
*
sizeof
(
CDataType
));
if
(
errA
!=
hipSuccess
)
{
std
::
cerr
<<
"Error allocating device memory for A: "
<<
hipGetErrorString
(
errA
)
<<
std
::
endl
;
return
;
// Early exit on error
}
if
(
errB
!=
hipSuccess
)
{
std
::
cerr
<<
"Error allocating device memory for B: "
<<
hipGetErrorString
(
errB
)
<<
std
::
endl
;
return
;
// Early exit on error
}
if
(
errC
!=
hipSuccess
)
{
std
::
cerr
<<
"Error allocating device memory for C: "
<<
hipGetErrorString
(
errC
)
<<
std
::
endl
;
return
;
// Early exit on error
}
errA
=
hipMemcpy
(
d_A
,
a_device
.
GetDeviceBuffer
(),
batch_count
*
M
*
K
*
sizeof
(
ADataType
),
hipMemcpyHostToDevice
);
if
(
errA
!=
hipSuccess
)
{
std
::
cerr
<<
"Error copying A to device: "
<<
hipGetErrorString
(
errA
)
<<
std
::
endl
;
}
errB
=
hipMemcpy
(
d_B
,
b_device
.
GetDeviceBuffer
(),
batch_count
*
N
*
K
*
sizeof
(
BDataType
),
hipMemcpyHostToDevice
);
if
(
errB
!=
hipSuccess
)
{
std
::
cerr
<<
"Error copying B to device: "
<<
hipGetErrorString
(
errB
)
<<
std
::
endl
;
}
int
totalElements
=
M
*
N
;
int
numThreadsPerBlock
=
256
;
// Common choice for threads per block
int
numBlocks
=
(
totalElements
+
numThreadsPerBlock
-
1
)
/
numThreadsPerBlock
;
for
(
index_t
batch_id
=
0
;
batch_id
<
batch_count
;
++
batch_id
)
{
ADataType
*
d_ATemp
=
d_A
+
batch_id
*
batch_stride_A
;
BDataType
*
d_BTemp
=
d_B
+
batch_id
*
batch_stride_B
;
CDataType
*
d_CTemp
=
d_C
+
batch_id
*
batch_stride_C
;
ADataType
*
d_ATemp
=
a_ptr
+
batch_id
*
batch_stride_A
;
BDataType
*
d_BTemp
=
b_ptr
+
batch_id
*
batch_stride_B
;
CDataType
*
d_CTemp
=
c_ptr
+
batch_id
*
batch_stride_C
;
naive_gemm_kernel
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
LayoutA
,
LayoutB
,
LayoutC
>
<<<
numBlocks
,
numThreadsPerBlock
>>>
(
d_ATemp
,
d_BTemp
,
d_CTemp
,
M
,
N
,
K
,
stride_a
,
stride_b
,
stride_c
);
}
errC
=
hipMemcpy
(
c_device
.
GetDeviceBuffer
(),
d_C
,
batch_count
*
M
*
N
*
sizeof
(
CDataType
),
hipMemcpyDeviceToHost
);
if
(
errC
!=
hipSuccess
)
{
std
::
cerr
<<
"Error copying C to device: "
<<
hipGetErrorString
(
errC
)
<<
std
::
endl
;
}
errA
=
hipFree
(
d_A
);
if
(
errA
!=
hipSuccess
)
{
std
::
cerr
<<
"Error free the A memory: "
<<
hipGetErrorString
(
errA
)
<<
std
::
endl
;
}
errB
=
hipFree
(
d_B
);
if
(
errB
!=
hipSuccess
)
{
std
::
cerr
<<
"Error free the B memory: "
<<
hipGetErrorString
(
errB
)
<<
std
::
endl
;
}
errC
=
hipFree
(
d_C
);
if
(
errC
!=
hipSuccess
)
{
std
::
cerr
<<
"Error free the C memory: "
<<
hipGetErrorString
(
errC
)
<<
std
::
endl
;
}
return
;
}
}
// namespace ck_tile
include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp
View file @
a4501f13
...
...
@@ -8,16 +8,40 @@
namespace
ck_tile
{
// Note: for simplicity, each functor only care about single M
struct
reference_rmsnorm2d_default_epilogue
{
template
<
typename
OutDataType
,
typename
AccDataType
>
void
operator
()(
int
m
,
HostTensor
<
OutDataType
>&
o
,
const
HostTensor
<
AccDataType
>&
acc
)
{
const
int
N
=
acc
.
mDesc
.
get_lengths
()[
1
];
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
o
(
m
,
n
)
=
ck_tile
::
type_convert
<
OutDataType
>
(
acc
(
m
,
n
));
}
}
template
<
typename
OutDataType
,
typename
AccDataType
>
auto
operator
()(
int
m
,
const
HostTensor
<
AccDataType
>&
acc
)
{
HostTensor
<
OutDataType
>
o
(
acc
.
get_lengths
(),
acc
.
get_strides
());
operator
()(
m
,
o
,
acc
);
return
o
;
}
};
template
<
typename
XDataType
,
typename
GammaDataType
,
typename
ComputeDataType
,
typename
YDataType
,
typename
InvRmsDataType
>
typename
InvRmsDataType
,
typename
Epilogue
=
reference_rmsnorm2d_default_epilogue
>
void
reference_rmsnorm2d_fwd
(
const
HostTensor
<
XDataType
>&
x_m_n
,
const
HostTensor
<
GammaDataType
>&
gamma_n
,
HostTensor
<
YDataType
>&
y_m_n
,
HostTensor
<
InvRmsDataType
>&
invRms_m
,
ComputeDataType
epsilon
)
ComputeDataType
epsilon
,
Epilogue
epilogue_functor
=
{})
{
auto
rmsnorm2d_fwd_func
=
[
&
](
auto
m
)
{
const
int
N
=
x_m_n
.
mDesc
.
get_lengths
()[
1
];
...
...
@@ -37,13 +61,15 @@ void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
if
constexpr
(
!
std
::
is_same_v
<
InvRmsDataType
,
ck_tile
::
null_type
>
)
invRms_m
(
m
)
=
ck_tile
::
type_convert
<
InvRmsDataType
>
(
divisor
);
HostTensor
<
ComputeDataType
>
acc
(
x_m_n
.
get_lengths
(),
x_m_n
.
get_strides
());
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
ComputeDataType
x
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
x_m_n
(
m
,
n
));
ComputeDataType
gamma
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
gamma_n
(
n
));
auto
y
=
x
*
divisor
*
gamma
;
y_m_n
(
m
,
n
)
=
ck_tile
::
type_convert
<
YDataType
>
(
y
);
acc
(
m
,
n
)
=
x
*
divisor
*
gamma
;
}
epilogue_functor
(
m
,
y_m_n
,
acc
);
};
make_ParallelTensorFunctor
(
rmsnorm2d_fwd_func
,
invRms_m
.
mDesc
.
get_lengths
()[
0
])(
...
...
include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp
View file @
a4501f13
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
include/ck_tile/ops/common.hpp
View file @
a4501f13
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
include/ck_tile/ops/elementwise.hpp
View file @
a4501f13
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp
View file @
a4501f13
...
...
@@ -719,7 +719,82 @@ struct Silu
constexpr
T
one
=
type_convert
<
T
>
(
1
);
y
=
x
*
(
one
/
(
one
+
ck_tile
::
exp
(
-
x
)));
};
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
fp32x2_t
>
(
fp32x2_t
&
y
,
const
fp32x2_t
&
x
)
const
{
constexpr
auto
one
=
type_convert
<
float
>
(
1
);
y
[
0
]
=
x
[
0
]
*
__builtin_amdgcn_rcpf
(
one
+
ck_tile
::
exp
(
-
x
[
0
]));
y
[
1
]
=
x
[
1
]
*
__builtin_amdgcn_rcpf
(
one
+
ck_tile
::
exp
(
-
x
[
1
]));
};
};
#if 0
// Silu, the formular is not so good to do inline asm (dependency)
// we put the code here purposely if in the future ppl want to try
struct SiluAsm
{
template <typename T>
CK_TILE_HOST void operator()(T& y, T& x) const
{
static_assert(std::is_same_v<T, float>, "Data type is not supported by this operation!");
constexpr T one = type_convert<T>(1);
y = x * (one / (one + ck_tile::exp(-x)));
};
template <typename T>
CK_TILE_DEVICE void operator()(T& y, T& x) const
{
static_assert(std::is_same_v<T, float>, "Data type is not supported by this operation!");
const uint32_t log2e_neg_ = 0x3fb8aa3b | 0x80000000; // log2e_v<float> * -1;
// NOTE: x/y can't be same register before inline asm
// "+v" as y, "v" as x is not enought, x/y stil maybe put to same register
T tmp = x;
asm volatile("v_mul_f32 %[v_y], %[s_log2e], %[v_x]\n"
"v_exp_f32 %[v_y], %[v_y]\n"
"s_nop 0 ; hazard for exp\n"
"v_add_f32 %[v_y], %[v_y], 1.0\n"
"v_rcp_f32 %[v_y], %[v_y]\n"
"s_nop 0 ; hazard for rcp\n"
"v_mul_f32 %[v_y], %[v_x], %[v_y]\n"
: [v_y] "+v"(y), [v_x] "+v"(tmp)
: [s_log2e] "s"(log2e_neg_)
:);
};
template <>
CK_TILE_HOST void operator()<fp32x2_t>(fp32x2_t& y, fp32x2_t& x) const
{
constexpr auto one = type_convert<float>(1);
y[0] = x[0] * (one / (one + ck_tile::exp(-x[0])));
y[1] = x[1] * (one / (one + ck_tile::exp(-x[1])));
};
template <>
CK_TILE_DEVICE void operator()<fp32x2_t>(fp32x2_t& y, fp32x2_t& x) const
{
const uint32_t log2e_neg_ = 0x3fb8aa3b | 0x80000000; // log2e_v<float> * -1;
// NOTE: x/y can't be same register before inline asm
// float tmp0 = x[0], tmp1 = x[1];
asm volatile("v_mul_f32 %[v_y0], %[s_log2e], %[v_x0]\n"
"v_mul_f32 %[v_y1], %[s_log2e], %[v_x1]\n"
"v_exp_f32 %[v_y0], %[v_y0]\n"
"v_exp_f32 %[v_y1], %[v_y1]\n"
"v_add_f32 %[v_y0], %[v_y0], 1.0\n"
"v_add_f32 %[v_y1], %[v_y1], 1.0\n"
"v_rcp_f32 %[v_y0], %[v_y0]\n"
"v_rcp_f32 %[v_y1], %[v_y1]\n"
"v_mul_f32 %[v_y0], %[v_x0], %[v_y0]\n"
"v_mul_f32 %[v_y1], %[v_x1], %[v_y1]\n"
: [v_y0] "+v"(y[0]), [v_y1] "+v"(y[1]), [v_x0] "+v"(x[0]), [v_x1] "+v"(x[1])
: [s_log2e] "s"(log2e_neg_)
:);
};
};
#endif
struct
TanH
{
...
...
include/ck_tile/ops/epilogue.hpp
View file @
a4501f13
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
View file @
a4501f13
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -56,6 +56,13 @@ struct CShuffleEpilogue
// No additional shared memory needed
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
0
;
}
CK_TILE_HOST_DEVICE
static
constexpr
bool
IsOutputTransposed
()
{
// TODO: At now CShuffle doesn't allow to vector store after permute.
// It should be fixed and this function should return true.
return
false
;
}
template
<
typename
OAccTile
>
CK_TILE_DEVICE
void
permute_tile_data
(
OAccTile
&
o_acc_tile
)
{
...
...
@@ -111,7 +118,9 @@ struct CShuffleEpilogue
}
}
template
<
typename
ODramWindowTmp
,
typename
OAccTile
>
template
<
typename
ODramWindowTmp
,
typename
OAccTile
,
memory_operation_enum
out_memory_data_op
=
memory_operation_enum
::
set
>
CK_TILE_DEVICE
auto
operator
()(
ODramWindowTmp
&
o_dram_window_tmp
,
OAccTile
&
o_acc_tile
)
{
const
auto
&
current_window_origin
=
o_dram_window_tmp
.
get_window_origin
();
...
...
@@ -158,12 +167,26 @@ struct CShuffleEpilogue
// Store the tile data to the permuted location
if
constexpr
(
kPadM
||
kPadN
)
{
store_tile_raw
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
if
constexpr
(
out_memory_data_op
==
memory_operation_enum
::
set
)
{
store_tile_raw
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
}
else
{
update_tile_raw
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
}
buffer_store_fence
();
}
else
{
store_tile
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
if
constexpr
(
out_memory_data_op
==
memory_operation_enum
::
set
)
{
store_tile
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
}
else
{
update_tile
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
}
}
}
};
...
...
include/ck_tile/ops/epilogue/default_2d_epilogue.hpp
View file @
a4501f13
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -35,21 +35,39 @@ struct Default2DEpilogue
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
0
;
}
CK_TILE_HOST_DEVICE
static
constexpr
bool
IsOutputTransposed
()
{
return
false
;
}
// TODO: this function assume store out vector size is the same as OAccTile last dimension size
// how do we fix this ?
template
<
typename
ODramWindowTmp
,
typename
OAccTile
>
template
<
typename
ODramWindowTmp
,
typename
OAccTile
,
memory_operation_enum
out_memory_data_op
=
memory_operation_enum
::
set
>
CK_TILE_DEVICE
auto
operator
()(
ODramWindowTmp
&
o_dram_window_tmp
,
const
OAccTile
&
o_acc_tile
)
{
// TODO: this is ugly
if
constexpr
(
UseRawStore
&&
(
kPadM
||
kPadN
))
{
store_tile_raw
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
if
constexpr
(
out_memory_data_op
==
memory_operation_enum
::
set
)
{
store_tile_raw
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
}
else
{
update_tile_raw
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
}
buffer_store_fence
();
}
else
{
store_tile
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
if
constexpr
(
out_memory_data_op
==
memory_operation_enum
::
set
)
{
store_tile
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
}
else
{
update_tile
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
}
}
}
};
...
...
include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp
View file @
a4501f13
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -24,19 +24,19 @@ struct DynamicQuantEpilogueTraits
// this epilogue just store out a M*N matrix, row major
template
<
typename
AccDataType_
,
typename
X
ScaleDataType_
,
typename
Smooth
ScaleDataType_
,
typename
YScaleDataType_
,
typename
ODataType_
,
typename
BlockShape_
,
typename
Traits_
>
struct
DynamicQuantEpilogueProblem
{
using
AccDataType
=
remove_cvref_t
<
AccDataType_
>
;
using
X
ScaleDataType
=
remove_cvref_t
<
X
ScaleDataType_
>
;
using
YScaleDataType
=
remove_cvref_t
<
YScaleDataType_
>
;
using
ODataType
=
remove_cvref_t
<
ODataType_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
// can consum generic 2d shape
using
Traits
=
remove_cvref_t
<
Traits_
>
;
using
AccDataType
=
remove_cvref_t
<
AccDataType_
>
;
using
Smooth
ScaleDataType
=
remove_cvref_t
<
Smooth
ScaleDataType_
>
;
using
YScaleDataType
=
remove_cvref_t
<
YScaleDataType_
>
;
using
ODataType
=
remove_cvref_t
<
ODataType_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
// can consum generic 2d shape
using
Traits
=
remove_cvref_t
<
Traits_
>
;
};
// TODO: we should put descriptor creation function into policy
...
...
@@ -45,7 +45,7 @@ struct DynamicQuantEpilogue
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
AccDataType
=
remove_cvref_t
<
typename
Problem
::
AccDataType
>
;
using
X
ScaleDataType
=
remove_cvref_t
<
typename
Problem
::
X
ScaleDataType
>
;
using
Smooth
ScaleDataType
=
remove_cvref_t
<
typename
Problem
::
Smooth
ScaleDataType
>
;
using
YScaleDataType
=
remove_cvref_t
<
typename
Problem
::
YScaleDataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
using
BlockShape
=
remove_cvref_t
<
typename
Problem
::
BlockShape
>
;
...
...
@@ -78,7 +78,7 @@ struct DynamicQuantEpilogue
#if 0
// don't remove this
// Note that if we set encoding purposely like this, you will result in compile fail
// TODO:
x
_scale create local-scratch to accept arbitrary acc input (with same length)
// TODO:
sm
_scale create local-scratch to accept arbitrary acc input (with same length)
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M>,
...
...
@@ -105,34 +105,18 @@ struct DynamicQuantEpilogue
return
reduce_crosswarp_sync
.
GetSmemSize
();
}
// TODO: this function assume store out vector size is the same as OAccTile last dimension size
// how do we fix this ?
template
<
typename
ODramWindowTmp
,
typename
XScaleWindow
,
typename
YScaleWindow
,
typename
OAccTile
>
CK_TILE_DEVICE
auto
operator
()(
ODramWindowTmp
&
o_dram_window_tmp
,
const
XScaleWindow
&
x_scale_window_
,
YScaleWindow
&
y_scale_window
,
const
OAccTile
&
o_acc_tile
,
void
*
smem
)
template
<
typename
ODramWindowTmp
,
typename
YScaleWindow
,
typename
OAccTile
>
CK_TILE_DEVICE
auto
Impl
(
ODramWindowTmp
&
o_dram_window_tmp
,
YScaleWindow
&
y_scale_window
,
const
OAccTile
&
o_acc_tile
,
void
*
smem
)
{
auto
reduce
=
GetBlockReduce2d
();
auto
reduce_sync
=
GetBlockReduce2dSync
();
auto
reduce_crosswarp_sync
=
GetBlockReduce2dCrossWarpSync
();
const
auto
x_scale_window
=
make_tile_window
(
x_scale_window_
,
MakeSmoothInputScaleTileDistribution
());
auto
x_scale
=
load_tile
(
x_scale_window
);
auto
o_acc_tmp
=
o_acc_tile
;
sweep_tile
(
o_acc_tmp
,
[
&
](
auto
idx
)
{
constexpr
auto
j_idx
=
make_tuple
(
idx
[
number
<
1
>
{}]);
const
auto
xs_
=
type_convert
<
AccDataType
>
(
x_scale
[
j_idx
]);
o_acc_tmp
(
idx
)
=
o_acc_tmp
(
idx
)
*
xs_
;
});
const
auto
f_absmax
=
[](
auto
acc_
,
auto
v_0_
)
{
return
max
(
acc_
,
abs
(
v_0_
));
};
auto
row_absmax
=
[
&
]()
{
...
...
@@ -184,5 +168,45 @@ struct DynamicQuantEpilogue
store_tile
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tmp
));
}
}
// TODO: this function assume store out vector size is the same as OAccTile last dimension size
// how do we fix this ?
// Smooth Dynamic Quant
template
<
typename
ODramWindowTmp
,
typename
SmoothScaleWindow
,
typename
YScaleWindow
,
typename
OAccTile
>
CK_TILE_DEVICE
auto
operator
()(
ODramWindowTmp
&
o_dram_window_tmp
,
const
SmoothScaleWindow
&
sm_scale_window_
,
YScaleWindow
&
y_scale_window
,
const
OAccTile
&
o_acc_tile
,
void
*
smem
)
{
const
auto
sm_scale_window
=
make_tile_window
(
sm_scale_window_
,
MakeSmoothInputScaleTileDistribution
());
auto
sm_scale
=
load_tile
(
sm_scale_window
);
auto
o_acc_tmp
=
o_acc_tile
;
sweep_tile
(
o_acc_tmp
,
[
&
](
auto
idx
)
{
constexpr
auto
j_idx
=
make_tuple
(
idx
[
number
<
1
>
{}]);
const
auto
xs_
=
type_convert
<
AccDataType
>
(
sm_scale
[
j_idx
]);
o_acc_tmp
(
idx
)
=
o_acc_tmp
(
idx
)
*
xs_
;
});
Impl
(
o_dram_window_tmp
,
y_scale_window
,
o_acc_tmp
,
smem
);
}
// Dynamic Quant
template
<
typename
ODramWindowTmp
,
typename
YScaleWindow
,
typename
OAccTile
>
CK_TILE_DEVICE
auto
operator
()(
ODramWindowTmp
&
o_dram_window_tmp
,
YScaleWindow
&
y_scale_window
,
const
OAccTile
&
o_acc_tile
,
void
*
smem
)
{
Impl
(
o_dram_window_tmp
,
y_scale_window
,
o_acc_tile
,
smem
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/flatmm.hpp
View file @
a4501f13
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp
View file @
a4501f13
...
...
@@ -234,10 +234,153 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
return
32
*
(
128
+
8
)
*
sizeof
(
bf16_t
);
// return 32 * (128 + 8) * sizeof(bf16_t);
return
MakeLdsLoadDesc_A
().
get_element_space_size
()
*
sizeof
(
bf16_t
)
*
2
;
// 2 lds buffers
}
};
// clang-format off
#define _EXPAND_ASM_ARGS_OUT_ONE_ACC \
[s_loop_cnt]"+s"(loop_cnt), \
[v_acc_0]"+v"(v_acc[0]), \
[v_acc_1]"+v"(v_acc[1]), \
[v_acc_2]"+v"(v_acc[2]), \
[v_acc_3]"+v"(v_acc[3]), \
[v_acc_4]"+v"(v_acc[4]), \
[v_acc_5]"+v"(v_acc[5]), \
[v_acc_6]"+v"(v_acc[6]), \
[v_acc_7]"+v"(v_acc[7]), \
[v_acc_8]"+v"(v_acc[8]), \
[v_acc_9]"+v"(v_acc[9]), \
[v_acc_10]"+v"(v_acc[10]), \
[v_acc_11]"+v"(v_acc[11]), \
[v_acc_12]"+v"(v_acc[12]), \
[v_acc_13]"+v"(v_acc[13]), \
[v_acc_14]"+v"(v_acc[14]), \
[v_acc_15]"+v"(v_acc[15]), \
[s_mem_]"+r"(smem)
#define _EXPAND_ASM_ARGS_OUT_TWO_ACC \
[s_loop_cnt]"+s"(loop_cnt), \
[v_acc_0]"+v"(v_acc[0]), \
[v_acc_1]"+v"(v_acc[1]), \
[v_acc_2]"+v"(v_acc[2]), \
[v_acc_3]"+v"(v_acc[3]), \
[v_acc_4]"+v"(v_acc[4]), \
[v_acc_5]"+v"(v_acc[5]), \
[v_acc_6]"+v"(v_acc[6]), \
[v_acc_7]"+v"(v_acc[7]), \
[v_acc_8]"+v"(v_acc[8]), \
[v_acc_9]"+v"(v_acc[9]), \
[v_acc_10]"+v"(v_acc[10]), \
[v_acc_11]"+v"(v_acc[11]), \
[v_acc_12]"+v"(v_acc[12]), \
[v_acc_13]"+v"(v_acc[13]), \
[v_acc_14]"+v"(v_acc[14]), \
[v_acc_15]"+v"(v_acc[15]), \
[v_acc_16]"+v"(v_acc[16]), \
[v_acc_17]"+v"(v_acc[17]), \
[v_acc_18]"+v"(v_acc[18]), \
[v_acc_19]"+v"(v_acc[19]), \
[v_acc_20]"+v"(v_acc[20]), \
[v_acc_21]"+v"(v_acc[21]), \
[v_acc_22]"+v"(v_acc[22]), \
[v_acc_23]"+v"(v_acc[23]), \
[v_acc_24]"+v"(v_acc[24]), \
[v_acc_25]"+v"(v_acc[25]), \
[v_acc_26]"+v"(v_acc[26]), \
[v_acc_27]"+v"(v_acc[27]), \
[v_acc_28]"+v"(v_acc[28]), \
[v_acc_29]"+v"(v_acc[29]), \
[v_acc_30]"+v"(v_acc[30]), \
[v_acc_31]"+v"(v_acc[31]), \
[s_mem_]"+r"(smem)
#define _EXPAND_ASM_ARGS_IN \
[s_res_a0]"s"(res_a[0]), \
[s_res_a1]"s"(res_a[1]), \
[s_res_a2]"s"(res_a[2]), \
[s_res_a3]"s"(res_a[3]), \
[s_res_b0]"s"(res_b[0]), \
[s_res_b1]"s"(res_b[1]), \
[s_res_b2]"s"(res_b[2]), \
[s_res_b3]"s"(res_b[3]), \
[v_os_a0]"v"(static_cast<index_t>(cached_coords_a[number<0>{}] * sizeof(ADataType))), \
[v_os_a1]"v"(static_cast<index_t>(cached_coords_a[number<1>{}] * sizeof(ADataType))), \
[v_os_a2]"v"(static_cast<index_t>(cached_coords_a[number<2>{}] * sizeof(ADataType))), \
[v_os_a3]"v"(static_cast<index_t>(cached_coords_a[number<3>{}] * sizeof(ADataType))), \
[v_os_a4]"v"(static_cast<index_t>(cached_coords_a[number<4>{}] * sizeof(ADataType))), \
[v_os_a5]"v"(static_cast<index_t>(cached_coords_a[number<5>{}] * sizeof(ADataType))), \
[v_os_a6]"v"(static_cast<index_t>(cached_coords_a[number<6>{}] * sizeof(ADataType))), \
[v_os_a7]"v"(static_cast<index_t>(cached_coords_a[number<7>{}] * sizeof(ADataType))), \
\
[v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))), \
[v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))), \
[v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))), \
[v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))), \
[v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))), \
[v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))), \
[v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))), \
[v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))), \
\
[v_os_slda]"v"(static_cast<index_t>(a_sld.cached_coords_[number<0>{}].get_offset() * sizeof(ADataType))),\
[s_m0_init]"s"(m0_init_value), \
[s_size_per_issue]"s"(size_per_issue), \
[smem_sz]"n"(smem_buf_size), \
[sld_os_0]"n"(sld_os[number<0>{}].value), \
[sld_os_1]"n"(sld_os[number<1>{}].value), \
[sld_os_2]"n"(sld_os[number<2>{}].value), \
[sld_os_3]"n"(sld_os[number<3>{}].value), \
[sld_os_4]"n"(sld_os[number<4>{}].value), \
[sld_os_5]"n"(sld_os[number<5>{}].value), \
[sld_os_6]"n"(sld_os[number<6>{}].value), \
[sld_os_7]"n"(sld_os[number<7>{}].value), \
[s_tile_os_a]"s"(tile_offset_a_bytes), \
[s_tile_os_b]"s"(tile_offset_b_bytes)
#define _EXPAND_ASM_ARGS_CLOBBER \
"memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9", \
"a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19", \
"a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29", \
"a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39", \
"a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49", \
"a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59", \
"a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69", \
"a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79", \
"a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89", \
"a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99", \
"a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107", \
"a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115", \
"a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123", \
"a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131", \
"a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139", \
"a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147", \
"a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155", \
"a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163", \
"a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171", \
"a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179", \
"a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187", \
"a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195", \
"a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203", \
"a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211", \
"a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219", \
"a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227", \
"a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235", \
"a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243", \
"a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251", \
"a252", "a253", "a254", "a255", \
"s16", "s17", "s18", "s19", "s20", "s21", "s22", "s23", \
"s86", \
"v64", "v65", "v66", "v67", "v68", "v69", \
"v70", "v71", "v72", "v73", "v74", "v75", "v76", "v77", "v78", "v79", \
"v80", "v81", "v82", "v83", "v84", "v85", "v86", "v87", "v88", "v89", \
"v90", "v91", "v92", "v93", "v94", "v95", "v96", "v97", "v98", "v99", \
"v100", "v101", "v102", "v103", "v104", "v105", "v106", "v107", \
"v108", "v109", "v110", "v111", "v112", "v113", "v114", "v115", \
"v116", "v117", "v118", "v119", "v120", "v121", "v122", "v123", \
"v124", "v125", "v126", "v127"
// clang-format on
struct
Flatmm_32x512x128_1x4x1_16x16x32_BF16
:
public
Flatmm_32x512x128_1x4x1_16x16x32_Base
{
using
ADataType
=
bf16_t
;
...
...
@@ -245,7 +388,9 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_BF16 : public Flatmm_32x512x128_1x4x1_16
// TODO: need paired with tile_window_linear!
// TODO: need call init_raw() before call this function!
template
<
typename
ARes
,
typename
ACoords
,
typename
BRes
,
typename
BCoords
>
// Is2B: originally for B matrix we have 2 prefetch buffers. If set this to true
// we can support A matric serve 2 B matrix, B0/B1, each B0/B1 still have same tile size
template
<
typename
ARes
,
typename
ACoords
,
typename
BRes
,
typename
BCoords
,
bool
Is2B
=
false
>
CK_TILE_DEVICE
auto
operator
()(
const
ARes
&
res_a
,
const
ACoords
&
cached_coords_a
,
...
...
@@ -254,7 +399,8 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_BF16 : public Flatmm_32x512x128_1x4x1_16
CK_TILE_LDS_ADDR
void
*
smem
,
index_t
k
,
index_t
tile_offset_a
,
// for each tile, the offset to move for each unroll
index_t
tile_offset_b
)
// for each tile, the offset to move for each unroll
index_t
tile_offset_b
,
bool_constant
<
Is2B
>
=
{})
// for each tile, the offset to move for each unroll
{
static_assert
(
ACoords
::
size
()
==
Block_M
*
Block_K
/
BlockSize
/
2
/*2x per dword*/
);
// 8
static_assert
(
BCoords
::
size
()
==
Repeat_N
);
...
...
@@ -299,129 +445,78 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_BF16 : public Flatmm_32x512x128_1x4x1_16
index_t
loop_cnt
=
k
/
Block_K
;
// this is the acc thread buffer
fp32x4_t
v_acc
[
16
]{
.0
f
};
if
constexpr
(
Is2B
)
{
// this is the acc thread buffer
fp32x4_t
v_acc
[
32
]{
.0
f
};
// B nr->kr
// B nr->kr
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Winline-asm"
// clang-format off
asm
volatile
(
// clang-format off
asm
volatile
(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
#define CK_TILE_FLATMM_UK_2B 1
#include "uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc"
#undef CK_TILE_FLATMM_UK_MFMA
:
[
s_loop_cnt
]
"+s"
(
loop_cnt
),
[
v_acc_0
]
"+v"
(
v_acc
[
0
]),
[
v_acc_1
]
"+v"
(
v_acc
[
1
]),
[
v_acc_2
]
"+v"
(
v_acc
[
2
]),
[
v_acc_3
]
"+v"
(
v_acc
[
3
]),
[
v_acc_4
]
"+v"
(
v_acc
[
4
]),
[
v_acc_5
]
"+v"
(
v_acc
[
5
]),
[
v_acc_6
]
"+v"
(
v_acc
[
6
]),
[
v_acc_7
]
"+v"
(
v_acc
[
7
]),
[
v_acc_8
]
"+v"
(
v_acc
[
8
]),
[
v_acc_9
]
"+v"
(
v_acc
[
9
]),
[
v_acc_10
]
"+v"
(
v_acc
[
10
]),
[
v_acc_11
]
"+v"
(
v_acc
[
11
]),
[
v_acc_12
]
"+v"
(
v_acc
[
12
]),
[
v_acc_13
]
"+v"
(
v_acc
[
13
]),
[
v_acc_14
]
"+v"
(
v_acc
[
14
]),
[
v_acc_15
]
"+v"
(
v_acc
[
15
]),
[
s_mem_
]
"+r"
(
smem
)
:
[
s_res_a0
]
"s"
(
res_a
[
0
]),
[
s_res_a1
]
"s"
(
res_a
[
1
]),
[
s_res_a2
]
"s"
(
res_a
[
2
]),
[
s_res_a3
]
"s"
(
res_a
[
3
]),
[
s_res_b0
]
"s"
(
res_b
[
0
]),
[
s_res_b1
]
"s"
(
res_b
[
1
]),
[
s_res_b2
]
"s"
(
res_b
[
2
]),
[
s_res_b3
]
"s"
(
res_b
[
3
]),
[
v_os_a0
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_a
[
number
<
0
>
{}]
*
sizeof
(
ADataType
))),
[
v_os_a1
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_a
[
number
<
1
>
{}]
*
sizeof
(
ADataType
))),
[
v_os_a2
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_a
[
number
<
2
>
{}]
*
sizeof
(
ADataType
))),
[
v_os_a3
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_a
[
number
<
3
>
{}]
*
sizeof
(
ADataType
))),
[
v_os_a4
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_a
[
number
<
4
>
{}]
*
sizeof
(
ADataType
))),
[
v_os_a5
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_a
[
number
<
5
>
{}]
*
sizeof
(
ADataType
))),
[
v_os_a6
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_a
[
number
<
6
>
{}]
*
sizeof
(
ADataType
))),
[
v_os_a7
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_a
[
number
<
7
>
{}]
*
sizeof
(
ADataType
))),
[
v_os_b0
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
0
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b1
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
1
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b2
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
2
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b3
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
3
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b4
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
4
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b5
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
5
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b6
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
6
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b7
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
7
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_slda
]
"v"
(
static_cast
<
index_t
>
(
a_sld
.
cached_coords_
[
number
<
0
>
{}].
get_offset
()
*
sizeof
(
ADataType
))),
[
s_m0_init
]
"s"
(
m0_init_value
),
[
s_size_per_issue
]
"s"
(
size_per_issue
),
[
smem_sz
]
"n"
(
smem_buf_size
),
//(smem_buf_size),
[
sld_os_0
]
"n"
(
sld_os
[
number
<
0
>
{}].
value
),
[
sld_os_1
]
"n"
(
sld_os
[
number
<
1
>
{}].
value
),
[
sld_os_2
]
"n"
(
sld_os
[
number
<
2
>
{}].
value
),
[
sld_os_3
]
"n"
(
sld_os
[
number
<
3
>
{}].
value
),
[
sld_os_4
]
"n"
(
sld_os
[
number
<
4
>
{}].
value
),
[
sld_os_5
]
"n"
(
sld_os
[
number
<
5
>
{}].
value
),
[
sld_os_6
]
"n"
(
sld_os
[
number
<
6
>
{}].
value
),
[
sld_os_7
]
"n"
(
sld_os
[
number
<
7
>
{}].
value
),
[
s_tile_os_a
]
"s"
(
tile_offset_a_bytes
),
[
s_tile_os_b
]
"s"
(
tile_offset_b_bytes
)
:
"memory"
,
"a0"
,
"a1"
,
"a2"
,
"a3"
,
"a4"
,
"a5"
,
"a6"
,
"a7"
,
"a8"
,
"a9"
,
"a10"
,
"a11"
,
"a12"
,
"a13"
,
"a14"
,
"a15"
,
"a16"
,
"a17"
,
"a18"
,
"a19"
,
"a20"
,
"a21"
,
"a22"
,
"a23"
,
"a24"
,
"a25"
,
"a26"
,
"a27"
,
"a28"
,
"a29"
,
"a30"
,
"a31"
,
"a32"
,
"a33"
,
"a34"
,
"a35"
,
"a36"
,
"a37"
,
"a38"
,
"a39"
,
"a40"
,
"a41"
,
"a42"
,
"a43"
,
"a44"
,
"a45"
,
"a46"
,
"a47"
,
"a48"
,
"a49"
,
"a50"
,
"a51"
,
"a52"
,
"a53"
,
"a54"
,
"a55"
,
"a56"
,
"a57"
,
"a58"
,
"a59"
,
"a60"
,
"a61"
,
"a62"
,
"a63"
,
"a64"
,
"a65"
,
"a66"
,
"a67"
,
"a68"
,
"a69"
,
"a70"
,
"a71"
,
"a72"
,
"a73"
,
"a74"
,
"a75"
,
"a76"
,
"a77"
,
"a78"
,
"a79"
,
"a80"
,
"a81"
,
"a82"
,
"a83"
,
"a84"
,
"a85"
,
"a86"
,
"a87"
,
"a88"
,
"a89"
,
"a90"
,
"a91"
,
"a92"
,
"a93"
,
"a94"
,
"a95"
,
"a96"
,
"a97"
,
"a98"
,
"a99"
,
"a100"
,
"a101"
,
"a102"
,
"a103"
,
"a104"
,
"a105"
,
"a106"
,
"a107"
,
"a108"
,
"a109"
,
"a110"
,
"a111"
,
"a112"
,
"a113"
,
"a114"
,
"a115"
,
"a116"
,
"a117"
,
"a118"
,
"a119"
,
"a120"
,
"a121"
,
"a122"
,
"a123"
,
"a124"
,
"a125"
,
"a126"
,
"a127"
,
"a128"
,
"a129"
,
"a130"
,
"a131"
,
"a132"
,
"a133"
,
"a134"
,
"a135"
,
"a136"
,
"a137"
,
"a138"
,
"a139"
,
"a140"
,
"a141"
,
"a142"
,
"a143"
,
"a144"
,
"a145"
,
"a146"
,
"a147"
,
"a148"
,
"a149"
,
"a150"
,
"a151"
,
"a152"
,
"a153"
,
"a154"
,
"a155"
,
"a156"
,
"a157"
,
"a158"
,
"a159"
,
"a160"
,
"a161"
,
"a162"
,
"a163"
,
"a164"
,
"a165"
,
"a166"
,
"a167"
,
"a168"
,
"a169"
,
"a170"
,
"a171"
,
"a172"
,
"a173"
,
"a174"
,
"a175"
,
"a176"
,
"a177"
,
"a178"
,
"a179"
,
"a180"
,
"a181"
,
"a182"
,
"a183"
,
"a184"
,
"a185"
,
"a186"
,
"a187"
,
"a188"
,
"a189"
,
"a190"
,
"a191"
,
"a192"
,
"a193"
,
"a194"
,
"a195"
,
"a196"
,
"a197"
,
"a198"
,
"a199"
,
"a200"
,
"a201"
,
"a202"
,
"a203"
,
"a204"
,
"a205"
,
"a206"
,
"a207"
,
"a208"
,
"a209"
,
"a210"
,
"a211"
,
"a212"
,
"a213"
,
"a214"
,
"a215"
,
"a216"
,
"a217"
,
"a218"
,
"a219"
,
"a220"
,
"a221"
,
"a222"
,
"a223"
,
"a224"
,
"a225"
,
"a226"
,
"a227"
,
"a228"
,
"a229"
,
"a230"
,
"a231"
,
"a232"
,
"a233"
,
"a234"
,
"a235"
,
"a236"
,
"a237"
,
"a238"
,
"a239"
,
"a240"
,
"a241"
,
"a242"
,
"a243"
,
"a244"
,
"a245"
,
"a246"
,
"a247"
,
"a248"
,
"a249"
,
"a250"
,
"a251"
,
"a252"
,
"a253"
,
"a254"
,
"a255"
,
"s16"
,
"s17"
,
"s18"
,
"s19"
,
"s20"
,
"s21"
,
"s22"
,
"s23"
,
"s86"
,
// s86 as tmp
"v64"
,
"v65"
,
"v66"
,
"v67"
,
"v68"
,
"v69"
,
"v70"
,
"v71"
,
"v72"
,
"v73"
,
"v74"
,
"v75"
,
"v76"
,
"v77"
,
"v78"
,
"v79"
,
"v80"
,
"v81"
,
"v82"
,
"v83"
,
"v84"
,
"v85"
,
"v86"
,
"v87"
,
"v88"
,
"v89"
,
"v90"
,
"v91"
,
"v92"
,
"v93"
,
"v94"
,
"v95"
,
"v96"
,
"v97"
,
"v98"
,
"v99"
,
"v100"
,
"v101"
,
"v102"
,
"v103"
,
"v104"
,
"v105"
,
"v106"
,
"v107"
,
"v108"
,
"v109"
,
"v110"
,
"v111"
,
"v112"
,
"v113"
,
"v114"
,
"v115"
,
"v116"
,
"v117"
,
"v118"
,
"v119"
,
"v120"
,
"v121"
,
"v122"
,
"v123"
,
"v124"
,
"v125"
,
"v126"
,
"v127"
);
// clang-format on
:
_EXPAND_ASM_ARGS_OUT_TWO_ACC
:
_EXPAND_ASM_ARGS_IN
,
[
s_res_b4
]
"s"
(
res_b
[
4
]),
[
s_res_b5
]
"s"
(
res_b
[
5
]),
[
s_res_b6
]
"s"
(
res_b
[
6
]),
[
s_res_b7
]
"s"
(
res_b
[
7
])
:
_EXPAND_ASM_ARGS_CLOBBER
,
"s24"
,
"s25"
,
"s26"
,
"s27"
);
// clang-format on
#pragma clang diagnostic pop
// return local scratch
auto
c
=
MakeCBlockTile
();
for
(
auto
i
=
0
;
i
<
16
;
i
++
)
// return local scratch
auto
c
=
make_tuple
(
MakeCBlockTile
(),
MakeCBlockTile
());
for
(
auto
i
=
0
;
i
<
16
;
i
++
)
{
c
.
at
(
number
<
0
>
{}).
get_thread_buffer
()[
4
*
i
+
0
]
=
v_acc
[
i
].
x
;
c
.
at
(
number
<
0
>
{}).
get_thread_buffer
()[
4
*
i
+
1
]
=
v_acc
[
i
].
y
;
c
.
at
(
number
<
0
>
{}).
get_thread_buffer
()[
4
*
i
+
2
]
=
v_acc
[
i
].
z
;
c
.
at
(
number
<
0
>
{}).
get_thread_buffer
()[
4
*
i
+
3
]
=
v_acc
[
i
].
w
;
}
for
(
auto
i
=
0
;
i
<
16
;
i
++
)
{
c
.
at
(
number
<
1
>
{}).
get_thread_buffer
()[
4
*
i
+
0
]
=
v_acc
[
16
+
i
].
x
;
c
.
at
(
number
<
1
>
{}).
get_thread_buffer
()[
4
*
i
+
1
]
=
v_acc
[
16
+
i
].
y
;
c
.
at
(
number
<
1
>
{}).
get_thread_buffer
()[
4
*
i
+
2
]
=
v_acc
[
16
+
i
].
z
;
c
.
at
(
number
<
1
>
{}).
get_thread_buffer
()[
4
*
i
+
3
]
=
v_acc
[
16
+
i
].
w
;
}
return
c
;
}
else
{
c
.
get_thread_buffer
()[
4
*
i
+
0
]
=
v_acc
[
i
].
x
;
c
.
get_thread_buffer
()[
4
*
i
+
1
]
=
v_acc
[
i
].
y
;
c
.
get_thread_buffer
()[
4
*
i
+
2
]
=
v_acc
[
i
].
z
;
c
.
get_thread_buffer
()[
4
*
i
+
3
]
=
v_acc
[
i
].
w
;
// this is the acc thread buffer
fp32x4_t
v_acc
[
16
]{
.0
f
};
// B nr->kr
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Winline-asm"
// clang-format off
asm
volatile
(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
#include "uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc"
:
_EXPAND_ASM_ARGS_OUT_ONE_ACC
:
_EXPAND_ASM_ARGS_IN
:
_EXPAND_ASM_ARGS_CLOBBER
);
// clang-format on
#pragma clang diagnostic pop
// return local scratch
auto
c
=
MakeCBlockTile
();
for
(
auto
i
=
0
;
i
<
16
;
i
++
)
{
c
.
get_thread_buffer
()[
4
*
i
+
0
]
=
v_acc
[
i
].
x
;
c
.
get_thread_buffer
()[
4
*
i
+
1
]
=
v_acc
[
i
].
y
;
c
.
get_thread_buffer
()[
4
*
i
+
2
]
=
v_acc
[
i
].
z
;
c
.
get_thread_buffer
()[
4
*
i
+
3
]
=
v_acc
[
i
].
w
;
}
return
c
;
}
return
c
;
}
};
...
...
@@ -432,7 +527,7 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_FP16 : public Flatmm_32x512x128_1x4x1_16
// TODO: need paired with tile_window_linear!
// TODO: need call init_raw() before call this function!
template
<
typename
ARes
,
typename
ACoords
,
typename
BRes
,
typename
BCoords
>
template
<
typename
ARes
,
typename
ACoords
,
typename
BRes
,
typename
BCoords
,
bool
Is2B
=
false
>
CK_TILE_DEVICE
auto
operator
()(
const
ARes
&
res_a
,
const
ACoords
&
cached_coords_a
,
...
...
@@ -441,7 +536,8 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_FP16 : public Flatmm_32x512x128_1x4x1_16
CK_TILE_LDS_ADDR
void
*
smem
,
index_t
k
,
index_t
tile_offset_a
,
// for each tile, the offset to move for each unroll
index_t
tile_offset_b
)
// for each tile, the offset to move for each unroll
index_t
tile_offset_b
,
// for each tile, the offset to move for each unroll
bool_constant
<
Is2B
>
=
{})
{
static_assert
(
ACoords
::
size
()
==
Block_M
*
Block_K
/
BlockSize
/
2
/*2x per dword*/
);
// 8
static_assert
(
BCoords
::
size
()
==
Repeat_N
);
...
...
@@ -486,130 +582,82 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_FP16 : public Flatmm_32x512x128_1x4x1_16
index_t
loop_cnt
=
k
/
Block_K
;
// this is the acc thread buffer
fp32x4_t
v_acc
[
16
]{
.0
f
};
if
constexpr
(
Is2B
)
{
// this is the acc thread buffer
fp32x4_t
v_acc
[
32
]{
.0
f
};
// B nr->kr
// B nr->kr
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Winline-asm"
// clang-format off
asm
volatile
(
// clang-format off
asm
volatile
(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
#define CK_TILE_FLATMM_UK_2B 1
#include "uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc"
#undef CK_TILE_FLATMM_UK_MFMA
:
[
s_loop_cnt
]
"+s"
(
loop_cnt
),
[
v_acc_0
]
"+v"
(
v_acc
[
0
]),
[
v_acc_1
]
"+v"
(
v_acc
[
1
]),
[
v_acc_2
]
"+v"
(
v_acc
[
2
]),
[
v_acc_3
]
"+v"
(
v_acc
[
3
]),
[
v_acc_4
]
"+v"
(
v_acc
[
4
]),
[
v_acc_5
]
"+v"
(
v_acc
[
5
]),
[
v_acc_6
]
"+v"
(
v_acc
[
6
]),
[
v_acc_7
]
"+v"
(
v_acc
[
7
]),
[
v_acc_8
]
"+v"
(
v_acc
[
8
]),
[
v_acc_9
]
"+v"
(
v_acc
[
9
]),
[
v_acc_10
]
"+v"
(
v_acc
[
10
]),
[
v_acc_11
]
"+v"
(
v_acc
[
11
]),
[
v_acc_12
]
"+v"
(
v_acc
[
12
]),
[
v_acc_13
]
"+v"
(
v_acc
[
13
]),
[
v_acc_14
]
"+v"
(
v_acc
[
14
]),
[
v_acc_15
]
"+v"
(
v_acc
[
15
]),
[
s_mem_
]
"+r"
(
smem
)
:
[
s_res_a0
]
"s"
(
res_a
[
0
]),
[
s_res_a1
]
"s"
(
res_a
[
1
]),
[
s_res_a2
]
"s"
(
res_a
[
2
]),
[
s_res_a3
]
"s"
(
res_a
[
3
]),
[
s_res_b0
]
"s"
(
res_b
[
0
]),
[
s_res_b1
]
"s"
(
res_b
[
1
]),
[
s_res_b2
]
"s"
(
res_b
[
2
]),
[
s_res_b3
]
"s"
(
res_b
[
3
]),
[
v_os_a0
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_a
[
number
<
0
>
{}]
*
sizeof
(
ADataType
))),
[
v_os_a1
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_a
[
number
<
1
>
{}]
*
sizeof
(
ADataType
))),
[
v_os_a2
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_a
[
number
<
2
>
{}]
*
sizeof
(
ADataType
))),
[
v_os_a3
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_a
[
number
<
3
>
{}]
*
sizeof
(
ADataType
))),
[
v_os_a4
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_a
[
number
<
4
>
{}]
*
sizeof
(
ADataType
))),
[
v_os_a5
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_a
[
number
<
5
>
{}]
*
sizeof
(
ADataType
))),
[
v_os_a6
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_a
[
number
<
6
>
{}]
*
sizeof
(
ADataType
))),
[
v_os_a7
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_a
[
number
<
7
>
{}]
*
sizeof
(
ADataType
))),
[
v_os_b0
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
0
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b1
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
1
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b2
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
2
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b3
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
3
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b4
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
4
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b5
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
5
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b6
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
6
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b7
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
7
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_slda
]
"v"
(
static_cast
<
index_t
>
(
a_sld
.
cached_coords_
[
number
<
0
>
{}].
get_offset
()
*
sizeof
(
ADataType
))),
[
s_m0_init
]
"s"
(
m0_init_value
),
[
s_size_per_issue
]
"s"
(
size_per_issue
),
[
smem_sz
]
"n"
(
smem_buf_size
),
//(smem_buf_size),
[
sld_os_0
]
"n"
(
sld_os
[
number
<
0
>
{}].
value
),
[
sld_os_1
]
"n"
(
sld_os
[
number
<
1
>
{}].
value
),
[
sld_os_2
]
"n"
(
sld_os
[
number
<
2
>
{}].
value
),
[
sld_os_3
]
"n"
(
sld_os
[
number
<
3
>
{}].
value
),
[
sld_os_4
]
"n"
(
sld_os
[
number
<
4
>
{}].
value
),
[
sld_os_5
]
"n"
(
sld_os
[
number
<
5
>
{}].
value
),
[
sld_os_6
]
"n"
(
sld_os
[
number
<
6
>
{}].
value
),
[
sld_os_7
]
"n"
(
sld_os
[
number
<
7
>
{}].
value
),
[
s_tile_os_a
]
"s"
(
tile_offset_a_bytes
),
[
s_tile_os_b
]
"s"
(
tile_offset_b_bytes
)
:
"memory"
,
"a0"
,
"a1"
,
"a2"
,
"a3"
,
"a4"
,
"a5"
,
"a6"
,
"a7"
,
"a8"
,
"a9"
,
"a10"
,
"a11"
,
"a12"
,
"a13"
,
"a14"
,
"a15"
,
"a16"
,
"a17"
,
"a18"
,
"a19"
,
"a20"
,
"a21"
,
"a22"
,
"a23"
,
"a24"
,
"a25"
,
"a26"
,
"a27"
,
"a28"
,
"a29"
,
"a30"
,
"a31"
,
"a32"
,
"a33"
,
"a34"
,
"a35"
,
"a36"
,
"a37"
,
"a38"
,
"a39"
,
"a40"
,
"a41"
,
"a42"
,
"a43"
,
"a44"
,
"a45"
,
"a46"
,
"a47"
,
"a48"
,
"a49"
,
"a50"
,
"a51"
,
"a52"
,
"a53"
,
"a54"
,
"a55"
,
"a56"
,
"a57"
,
"a58"
,
"a59"
,
"a60"
,
"a61"
,
"a62"
,
"a63"
,
"a64"
,
"a65"
,
"a66"
,
"a67"
,
"a68"
,
"a69"
,
"a70"
,
"a71"
,
"a72"
,
"a73"
,
"a74"
,
"a75"
,
"a76"
,
"a77"
,
"a78"
,
"a79"
,
"a80"
,
"a81"
,
"a82"
,
"a83"
,
"a84"
,
"a85"
,
"a86"
,
"a87"
,
"a88"
,
"a89"
,
"a90"
,
"a91"
,
"a92"
,
"a93"
,
"a94"
,
"a95"
,
"a96"
,
"a97"
,
"a98"
,
"a99"
,
"a100"
,
"a101"
,
"a102"
,
"a103"
,
"a104"
,
"a105"
,
"a106"
,
"a107"
,
"a108"
,
"a109"
,
"a110"
,
"a111"
,
"a112"
,
"a113"
,
"a114"
,
"a115"
,
"a116"
,
"a117"
,
"a118"
,
"a119"
,
"a120"
,
"a121"
,
"a122"
,
"a123"
,
"a124"
,
"a125"
,
"a126"
,
"a127"
,
"a128"
,
"a129"
,
"a130"
,
"a131"
,
"a132"
,
"a133"
,
"a134"
,
"a135"
,
"a136"
,
"a137"
,
"a138"
,
"a139"
,
"a140"
,
"a141"
,
"a142"
,
"a143"
,
"a144"
,
"a145"
,
"a146"
,
"a147"
,
"a148"
,
"a149"
,
"a150"
,
"a151"
,
"a152"
,
"a153"
,
"a154"
,
"a155"
,
"a156"
,
"a157"
,
"a158"
,
"a159"
,
"a160"
,
"a161"
,
"a162"
,
"a163"
,
"a164"
,
"a165"
,
"a166"
,
"a167"
,
"a168"
,
"a169"
,
"a170"
,
"a171"
,
"a172"
,
"a173"
,
"a174"
,
"a175"
,
"a176"
,
"a177"
,
"a178"
,
"a179"
,
"a180"
,
"a181"
,
"a182"
,
"a183"
,
"a184"
,
"a185"
,
"a186"
,
"a187"
,
"a188"
,
"a189"
,
"a190"
,
"a191"
,
"a192"
,
"a193"
,
"a194"
,
"a195"
,
"a196"
,
"a197"
,
"a198"
,
"a199"
,
"a200"
,
"a201"
,
"a202"
,
"a203"
,
"a204"
,
"a205"
,
"a206"
,
"a207"
,
"a208"
,
"a209"
,
"a210"
,
"a211"
,
"a212"
,
"a213"
,
"a214"
,
"a215"
,
"a216"
,
"a217"
,
"a218"
,
"a219"
,
"a220"
,
"a221"
,
"a222"
,
"a223"
,
"a224"
,
"a225"
,
"a226"
,
"a227"
,
"a228"
,
"a229"
,
"a230"
,
"a231"
,
"a232"
,
"a233"
,
"a234"
,
"a235"
,
"a236"
,
"a237"
,
"a238"
,
"a239"
,
"a240"
,
"a241"
,
"a242"
,
"a243"
,
"a244"
,
"a245"
,
"a246"
,
"a247"
,
"a248"
,
"a249"
,
"a250"
,
"a251"
,
"a252"
,
"a253"
,
"a254"
,
"a255"
,
"s16"
,
"s17"
,
"s18"
,
"s19"
,
"s20"
,
"s21"
,
"s22"
,
"s23"
,
"s86"
,
// s86 as tmp
"v64"
,
"v65"
,
"v66"
,
"v67"
,
"v68"
,
"v69"
,
"v70"
,
"v71"
,
"v72"
,
"v73"
,
"v74"
,
"v75"
,
"v76"
,
"v77"
,
"v78"
,
"v79"
,
"v80"
,
"v81"
,
"v82"
,
"v83"
,
"v84"
,
"v85"
,
"v86"
,
"v87"
,
"v88"
,
"v89"
,
"v90"
,
"v91"
,
"v92"
,
"v93"
,
"v94"
,
"v95"
,
"v96"
,
"v97"
,
"v98"
,
"v99"
,
"v100"
,
"v101"
,
"v102"
,
"v103"
,
"v104"
,
"v105"
,
"v106"
,
"v107"
,
"v108"
,
"v109"
,
"v110"
,
"v111"
,
"v112"
,
"v113"
,
"v114"
,
"v115"
,
"v116"
,
"v117"
,
"v118"
,
"v119"
,
"v120"
,
"v121"
,
"v122"
,
"v123"
,
"v124"
,
"v125"
,
"v126"
,
"v127"
);
// clang-format on
:
_EXPAND_ASM_ARGS_OUT_TWO_ACC
:
_EXPAND_ASM_ARGS_IN
,
[
s_res_b4
]
"s"
(
res_b
[
4
]),
[
s_res_b5
]
"s"
(
res_b
[
5
]),
[
s_res_b6
]
"s"
(
res_b
[
6
]),
[
s_res_b7
]
"s"
(
res_b
[
7
])
:
_EXPAND_ASM_ARGS_CLOBBER
,
"s24"
,
"s25"
,
"s26"
,
"s27"
);
// clang-format on
#pragma clang diagnostic pop
// return local scratch
auto
c
=
MakeCBlockTile
();
for
(
auto
i
=
0
;
i
<
16
;
i
++
)
// return local scratch
auto
c
=
make_tuple
(
MakeCBlockTile
(),
MakeCBlockTile
());
for
(
auto
i
=
0
;
i
<
16
;
i
++
)
{
c
.
at
(
number
<
0
>
{}).
get_thread_buffer
()[
4
*
i
+
0
]
=
v_acc
[
i
].
x
;
c
.
at
(
number
<
0
>
{}).
get_thread_buffer
()[
4
*
i
+
1
]
=
v_acc
[
i
].
y
;
c
.
at
(
number
<
0
>
{}).
get_thread_buffer
()[
4
*
i
+
2
]
=
v_acc
[
i
].
z
;
c
.
at
(
number
<
0
>
{}).
get_thread_buffer
()[
4
*
i
+
3
]
=
v_acc
[
i
].
w
;
}
for
(
auto
i
=
0
;
i
<
16
;
i
++
)
{
c
.
at
(
number
<
1
>
{}).
get_thread_buffer
()[
4
*
i
+
0
]
=
v_acc
[
16
+
i
].
x
;
c
.
at
(
number
<
1
>
{}).
get_thread_buffer
()[
4
*
i
+
1
]
=
v_acc
[
16
+
i
].
y
;
c
.
at
(
number
<
1
>
{}).
get_thread_buffer
()[
4
*
i
+
2
]
=
v_acc
[
16
+
i
].
z
;
c
.
at
(
number
<
1
>
{}).
get_thread_buffer
()[
4
*
i
+
3
]
=
v_acc
[
16
+
i
].
w
;
}
return
c
;
}
else
{
c
.
get_thread_buffer
()[
4
*
i
+
0
]
=
v_acc
[
i
].
x
;
c
.
get_thread_buffer
()[
4
*
i
+
1
]
=
v_acc
[
i
].
y
;
c
.
get_thread_buffer
()[
4
*
i
+
2
]
=
v_acc
[
i
].
z
;
c
.
get_thread_buffer
()[
4
*
i
+
3
]
=
v_acc
[
i
].
w
;
// this is the acc thread buffer
fp32x4_t
v_acc
[
16
]{
.0
f
};
// B nr->kr
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Winline-asm"
// clang-format off
asm
volatile
(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
#include "uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc"
:
_EXPAND_ASM_ARGS_OUT_ONE_ACC
:
_EXPAND_ASM_ARGS_IN
:
_EXPAND_ASM_ARGS_CLOBBER
);
// clang-format on
#pragma clang diagnostic pop
// return local scratch
auto
c
=
MakeCBlockTile
();
for
(
auto
i
=
0
;
i
<
16
;
i
++
)
{
c
.
get_thread_buffer
()[
4
*
i
+
0
]
=
v_acc
[
i
].
x
;
c
.
get_thread_buffer
()[
4
*
i
+
1
]
=
v_acc
[
i
].
y
;
c
.
get_thread_buffer
()[
4
*
i
+
2
]
=
v_acc
[
i
].
z
;
c
.
get_thread_buffer
()[
4
*
i
+
3
]
=
v_acc
[
i
].
w
;
}
return
c
;
}
return
c
;
}
};
#undef _EXPAND_ASM_ARGS_OUT_ONE_ACC
#undef _EXPAND_ASM_ARGS_OUT_TWO_ACC
#undef _EXPAND_ASM_ARGS_IN
#undef _EXPAND_ASM_ARGS_CLOBBER
}
// namespace ck_tile
include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp
View file @
a4501f13
...
...
@@ -65,7 +65,8 @@ struct FlatmmSn_32x128x512_1x4x1_16x16x32_Base
// in LDS we need store as
// M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16)*Nv(4) + 4)
// y y wave-id lid/16 lid%16 v
return
2
*
2
*
4
*
4
*
(
16
*
4
+
4
)
*
sizeof
(
bf16_t
);
constexpr
index_t
nbufs
=
2
;
return
2
*
2
*
4
*
4
*
(
16
*
4
+
4
)
*
sizeof
(
bf16_t
)
*
nbufs
;
}
};
...
...
@@ -173,7 +174,6 @@ struct FlatmmSn_32x128x512_1x4x1_16x16x32_BF16 : public FlatmmSn_32x128x512_1x4x
asm
volatile
(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
#include "uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc"
#undef CK_TILE_FLATMM_UK_MFMA
:
[
smem_
]
"+r"
(
smem
),
[
s_loop_cnt
]
"+s"
(
loop_cnt
),
[
c0
]
"+v"
(
v_c0
),
...
...
@@ -418,7 +418,6 @@ struct FlatmmSn_32x128x512_1x4x1_16x16x32_FP16 : public FlatmmSn_32x128x512_1x4x
asm
volatile
(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
#include "uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc"
#undef CK_TILE_FLATMM_UK_MFMA
:
[
smem_
]
"+r"
(
smem
),
[
s_loop_cnt
]
"+s"
(
loop_cnt
),
[
c0
]
"+v"
(
v_c0
),
...
...
include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp
0 → 100644
View file @
a4501f13
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp"
namespace
ck_tile
{
// "S"tream update output along "N"
// A in smem, B load from global
// require 4 wave, occupancy=1c
struct
FlatmmSn_32x128x512_1x4x1_16x16x32_BF16_itl
:
public
FlatmmSn_32x128x512_1x4x1_16x16x32_Base
{
using
BDataType
=
bf16_t
;
using
ODataType
=
bf16_t
;
// TODO: need paired with tile_window_linear!
// TODO: need call init_raw() before call this function!
// template <typename AWindow, typename BWindow, typename OWindow, typename ScaleTensor>
template
<
typename
BRes
,
typename
BCoords
,
typename
ORes
,
typename
OCoords
,
typename
OFlags
,
typename
ScaleTensor
>
CK_TILE_DEVICE
auto
operator
()(
const
BRes
&
res_b
,
const
BCoords
&
cached_coords_b
,
const
ORes
&
res_o
,
const
OCoords
&
cached_coords_o
,
const
OFlags
&
o_flags
,
// this should be in sgpr
CK_TILE_LDS_ADDR
void
*
smem
,
index_t
n
,
// loop along n dim
const
ScaleTensor
&
scale_
,
index_t
tile_offset_b
,
// stride b is fixed to blockKr * blockW, but still can adjust
index_t
tile_offset_o
)
{
static_assert
(
BCoords
::
size
()
==
8
);
// 8
static_assert
(
OCoords
::
size
()
==
8
);
const
index_t
tile_stride_b_bytes
=
tile_offset_b
*
sizeof
(
BDataType
);
const
index_t
tile_stride_o_bytes
=
tile_offset_o
*
sizeof
(
ODataType
);
static_assert
(
ScaleTensor
::
size
()
==
2
);
float
s0
=
scale_
[
number
<
0
>
{}];
float
s1
=
scale_
[
number
<
1
>
{}];
// index_t loop_cnt = n / Block_N;
register
float
v_c0
asm
(
"v64"
);
register
float
v_c1
asm
(
"v65"
);
register
float
v_c2
asm
(
"v66"
);
register
float
v_c3
asm
(
"v67"
);
register
float
v_c4
asm
(
"v68"
);
register
float
v_c5
asm
(
"v69"
);
register
float
v_c6
asm
(
"v70"
);
register
float
v_c7
asm
(
"v71"
);
register
float
v_c8
asm
(
"v72"
);
register
float
v_c9
asm
(
"v73"
);
register
float
v_c10
asm
(
"v74"
);
register
float
v_c11
asm
(
"v75"
);
register
float
v_c12
asm
(
"v76"
);
register
float
v_c13
asm
(
"v77"
);
register
float
v_c14
asm
(
"v78"
);
register
float
v_c15
asm
(
"v79"
);
register
float
v_c16
asm
(
"v80"
);
register
float
v_c17
asm
(
"v81"
);
register
float
v_c18
asm
(
"v82"
);
register
float
v_c19
asm
(
"v83"
);
register
float
v_c20
asm
(
"v84"
);
register
float
v_c21
asm
(
"v85"
);
register
float
v_c22
asm
(
"v86"
);
register
float
v_c23
asm
(
"v87"
);
register
float
v_c24
asm
(
"v88"
);
register
float
v_c25
asm
(
"v89"
);
register
float
v_c26
asm
(
"v90"
);
register
float
v_c27
asm
(
"v91"
);
register
float
v_c28
asm
(
"v92"
);
register
float
v_c29
asm
(
"v93"
);
register
float
v_c30
asm
(
"v94"
);
register
float
v_c31
asm
(
"v95"
);
int32_t
nan_hi
=
0x7fff0000
;
int32_t
nan_lo
=
0x00007fff
;
// in smem, the layout is M0(2)*K0(128)*M1(16)*K1(4)
// every threads need 8xK in contiguous register
// ... and every wave need the same data
int
lane_id
=
threadIdx
.
x
%
64
;
int
sld_y_os
=
(
lane_id
%
16
)
*
4
+
(
lane_id
/
16
)
*
128
;
sld_y_os
*=
2
;
// y y p p p y
// reg before shfl M0(2)*N0(2)*Nl(4)*Nw(4)*Mw(16)*Nv(4)
// but order is N0*M0*Nv
// in LDS we need store as
// M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16)*Nv(4) + 4)
// y y wave-id lid/16 lid%16 v
// sst(v3) = (v0/16*34 + v0%16 * 2 + wid*136) * 4
int
sfl_sst
=
(
threadIdx
.
x
%
16
*
4
)
+
(
threadIdx
.
x
/
16
)
*
(
64
+
4
);
sfl_sst
*=
2
;
// from LDS we need load as
// M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16) * Nv(4) + 4)
// ( 2 issue) (rem 32-lane) (4 wave*4issue) 2lane*1ussue(pk2)
// sld(v4) = v0/2 *34*4 + v0 % 2 *4 + wid*2 *4
int
sfl_sld
=
(
lane_id
%
2
)
*
2
+
(
lane_id
/
2
)
*
(
64
+
4
)
+
(
threadIdx
.
x
/
64
)
*
4
;
sfl_sld
*=
2
;
// B nr->kr
// clang-format off
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Winline-asm"
asm
volatile
(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
#include "uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16_itl.inc"
#undef CK_TILE_FLATMM_UK_MFMA
:
[
smem_
]
"+r"
(
smem
),
// [s_loop_cnt]"+s"(loop_cnt),
[
s_loop_cnt
]
"+s"
(
n
),
[
c0
]
"+v"
(
v_c0
),
[
c1
]
"+v"
(
v_c1
),
[
c2
]
"+v"
(
v_c2
),
[
c3
]
"+v"
(
v_c3
),
[
c4
]
"+v"
(
v_c4
),
[
c5
]
"+v"
(
v_c5
),
[
c6
]
"+v"
(
v_c6
),
[
c7
]
"+v"
(
v_c7
),
[
c8
]
"+v"
(
v_c8
),
[
c9
]
"+v"
(
v_c9
),
[
c10
]
"+v"
(
v_c10
),
[
c11
]
"+v"
(
v_c11
),
[
c12
]
"+v"
(
v_c12
),
[
c13
]
"+v"
(
v_c13
),
[
c14
]
"+v"
(
v_c14
),
[
c15
]
"+v"
(
v_c15
),
[
c16
]
"+v"
(
v_c16
),
[
c17
]
"+v"
(
v_c17
),
[
c18
]
"+v"
(
v_c18
),
[
c19
]
"+v"
(
v_c19
),
[
c20
]
"+v"
(
v_c20
),
[
c21
]
"+v"
(
v_c21
),
[
c22
]
"+v"
(
v_c22
),
[
c23
]
"+v"
(
v_c23
),
[
c24
]
"+v"
(
v_c24
),
[
c25
]
"+v"
(
v_c25
),
[
c26
]
"+v"
(
v_c26
),
[
c27
]
"+v"
(
v_c27
),
[
c28
]
"+v"
(
v_c28
),
[
c29
]
"+v"
(
v_c29
),
[
c30
]
"+v"
(
v_c30
),
[
c31
]
"+v"
(
v_c31
)
:
[
sld_a_base
]
"n"
(
0
),
[
shfl_base
]
"n"
(
0
),
[
v_sld_y_os
]
"v"
(
sld_y_os
),
[
v_sfl_sld
]
"v"
(
sfl_sld
),
[
v_sfl_sst
]
"v"
(
sfl_sst
),
[
s_res_o0
]
"s"
(
res_o
[
0
]),
[
s_res_o1
]
"s"
(
res_o
[
1
]),
//[s_res_o2]"s"(res_o[2]),
//[s_res_o3]"s"(res_o[3]),
[
s_res_b0
]
"s"
(
res_b
[
0
]),
[
s_res_b1
]
"s"
(
res_b
[
1
]),
[
s_res_b2
]
"s"
(
res_b
[
2
]),
[
s_res_b3
]
"s"
(
res_b
[
3
]),
[
v_os_o0
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
0
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o1
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
1
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o2
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
2
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o3
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
3
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o4
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
4
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o5
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
5
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o6
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
6
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o7
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
7
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_b0
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
0
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b1
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
1
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b2
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
2
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b3
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
3
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b4
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
4
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b5
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
5
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b6
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
6
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b7
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
7
>
{}]
*
sizeof
(
BDataType
))),
[
s_tile_os_o
]
"s"
(
tile_stride_o_bytes
),
[
s_tile_os_b
]
"s"
(
tile_stride_b_bytes
),
[
scale_0
]
"v"
(
s0
),
[
scale_1
]
"v"
(
s1
),
[
v_nan_lo
]
"v"
(
nan_lo
),
[
v_nan_hi
]
"v"
(
nan_hi
),
[
s_execflag_0
]
"s"
(
o_flags
[
number
<
0
>
{}]),
[
s_execflag_1
]
"s"
(
o_flags
[
number
<
1
>
{}]),
[
s_execflag_2
]
"s"
(
o_flags
[
number
<
2
>
{}]),
[
s_execflag_3
]
"s"
(
o_flags
[
number
<
3
>
{}]),
[
s_execflag_4
]
"s"
(
o_flags
[
number
<
4
>
{}]),
[
s_execflag_5
]
"s"
(
o_flags
[
number
<
5
>
{}]),
[
s_execflag_6
]
"s"
(
o_flags
[
number
<
6
>
{}]),
[
s_execflag_7
]
"s"
(
o_flags
[
number
<
7
>
{}])
:
"memory"
,
"a0"
,
"a1"
,
"a2"
,
"a3"
,
"a4"
,
"a5"
,
"a6"
,
"a7"
,
"a8"
,
"a9"
,
"a10"
,
"a11"
,
"a12"
,
"a13"
,
"a14"
,
"a15"
,
"a16"
,
"a17"
,
"a18"
,
"a19"
,
"a20"
,
"a21"
,
"a22"
,
"a23"
,
"a24"
,
"a25"
,
"a26"
,
"a27"
,
"a28"
,
"a29"
,
"a30"
,
"a31"
,
"a32"
,
"a33"
,
"a34"
,
"a35"
,
"a36"
,
"a37"
,
"a38"
,
"a39"
,
"a40"
,
"a41"
,
"a42"
,
"a43"
,
"a44"
,
"a45"
,
"a46"
,
"a47"
,
"a48"
,
"a49"
,
"a50"
,
"a51"
,
"a52"
,
"a53"
,
"a54"
,
"a55"
,
"a56"
,
"a57"
,
"a58"
,
"a59"
,
"a60"
,
"a61"
,
"a62"
,
"a63"
,
"a64"
,
"a65"
,
"a66"
,
"a67"
,
"a68"
,
"a69"
,
"a70"
,
"a71"
,
"a72"
,
"a73"
,
"a74"
,
"a75"
,
"a76"
,
"a77"
,
"a78"
,
"a79"
,
"a80"
,
"a81"
,
"a82"
,
"a83"
,
"a84"
,
"a85"
,
"a86"
,
"a87"
,
"a88"
,
"a89"
,
"a90"
,
"a91"
,
"a92"
,
"a93"
,
"a94"
,
"a95"
,
"a96"
,
"a97"
,
"a98"
,
"a99"
,
"a100"
,
"a101"
,
"a102"
,
"a103"
,
"a104"
,
"a105"
,
"a106"
,
"a107"
,
"a108"
,
"a109"
,
"a110"
,
"a111"
,
"a112"
,
"a113"
,
"a114"
,
"a115"
,
"a116"
,
"a117"
,
"a118"
,
"a119"
,
"a120"
,
"a121"
,
"a122"
,
"a123"
,
"a124"
,
"a125"
,
"a126"
,
"a127"
,
"a128"
,
"a129"
,
"a130"
,
"a131"
,
"a132"
,
"a133"
,
"a134"
,
"a135"
,
"a136"
,
"a137"
,
"a138"
,
"a139"
,
"a140"
,
"a141"
,
"a142"
,
"a143"
,
"a144"
,
"a145"
,
"a146"
,
"a147"
,
"a148"
,
"a149"
,
"a150"
,
"a151"
,
"a152"
,
"a153"
,
"a154"
,
"a155"
,
"a156"
,
"a157"
,
"a158"
,
"a159"
,
"a160"
,
"a161"
,
"a162"
,
"a163"
,
"a164"
,
"a165"
,
"a166"
,
"a167"
,
"a168"
,
"a169"
,
"a170"
,
"a171"
,
"a172"
,
"a173"
,
"a174"
,
"a175"
,
"a176"
,
"a177"
,
"a178"
,
"a179"
,
"a180"
,
"a181"
,
"a182"
,
"a183"
,
"a184"
,
"a185"
,
"a186"
,
"a187"
,
"a188"
,
"a189"
,
"a190"
,
"a191"
,
"a192"
,
"a193"
,
"a194"
,
"a195"
,
"a196"
,
"a197"
,
"a198"
,
"a199"
,
"a200"
,
"a201"
,
"a202"
,
"a203"
,
"a204"
,
"a205"
,
"a206"
,
"a207"
,
"a208"
,
"a209"
,
"a210"
,
"a211"
,
"a212"
,
"a213"
,
"a214"
,
"a215"
,
"a216"
,
"a217"
,
"a218"
,
"a219"
,
"a220"
,
"a221"
,
"a222"
,
"a223"
,
"a224"
,
"a225"
,
"a226"
,
"a227"
,
"a228"
,
"a229"
,
"a230"
,
"a231"
,
"a232"
,
"a233"
,
"a234"
,
"a235"
,
"a236"
,
"a237"
,
"a238"
,
"a239"
,
"a240"
,
"a241"
,
"a242"
,
"a243"
,
"a244"
,
"a245"
,
"a246"
,
"a247"
,
"a248"
,
"a249"
,
"a250"
,
"a251"
,
"a252"
,
"a253"
,
"a254"
,
"a255"
,
"s8"
,
"s9"
,
"s12"
,
"s13"
,
"s14"
,
"s15"
,
"s38"
,
"s39"
,
"s52"
,
"s86"
,
"s36"
,
"s37"
,
"s59"
,
"s80"
,
"v10"
,
"v11"
,
"v12"
,
"v13"
,
"v14"
,
"v15"
,
"v16"
,
"v17"
,
"v50"
,
"v54"
,
"v55"
,
"v64"
,
"v65"
,
"v66"
,
"v67"
,
"v68"
,
"v69"
,
"v70"
,
"v71"
,
"v72"
,
"v73"
,
"v74"
,
"v75"
,
"v76"
,
"v77"
,
"v78"
,
"v79"
,
"v80"
,
"v81"
,
"v82"
,
"v83"
,
"v84"
,
"v85"
,
"v86"
,
"v87"
,
"v88"
,
"v89"
,
"v90"
,
"v91"
,
"v92"
,
"v93"
,
"v94"
,
"v95"
,
"v128"
,
"v129"
,
"v130"
,
"v131"
,
"v132"
,
"v133"
,
"v134"
,
"v135"
,
"v136"
,
"v137"
,
"v138"
,
"v139"
,
"v140"
,
"v141"
,
"v142"
,
"v143"
,
"v144"
,
"v145"
,
"v146"
,
"v147"
,
"v148"
,
"v149"
,
"v150"
,
"v151"
,
"v152"
,
"v153"
,
"v154"
,
"v155"
,
"v156"
,
"v157"
,
"v158"
,
"v159"
,
"v160"
,
"v161"
,
"v162"
,
"v163"
,
"v164"
,
"v165"
,
"v166"
,
"v167"
,
"v168"
,
"v169"
,
"v170"
,
"v171"
,
"v172"
,
"v173"
,
"v174"
,
"v175"
,
"v176"
,
"v177"
,
"v178"
,
"v179"
,
"v180"
,
"v181"
,
"v182"
,
"v183"
,
"v184"
,
"v185"
,
"v186"
,
"v187"
,
"v188"
,
"v189"
,
"v190"
,
"v191"
,
"v192"
,
"v193"
,
"v194"
,
"v195"
,
"v196"
,
"v197"
,
"v198"
,
"v199"
,
"v200"
,
"v201"
,
"v202"
,
"v203"
,
"v204"
,
"v205"
,
"v206"
,
"v207"
,
"v208"
,
"v209"
,
"v210"
,
"v211"
,
"v212"
,
"v213"
,
"v214"
,
"v215"
,
"v216"
,
"v217"
,
"v218"
,
"v219"
,
"v220"
,
"v221"
,
"v222"
,
"v223"
,
"v224"
,
"v225"
,
"v226"
,
"v227"
,
"v228"
,
"v229"
,
"v230"
,
"v231"
,
"v232"
,
"v233"
,
"v234"
,
"v235"
,
"v236"
,
"v237"
,
"v238"
,
"v239"
,
"v240"
,
"v241"
,
"v242"
,
"v243"
,
"v244"
,
"v245"
,
"v246"
,
"v247"
,
"v248"
,
"v249"
,
"v250"
,
"v251"
,
"v252"
,
"v253"
,
"v254"
,
"v255"
);
#pragma clang diagnostic pop
// clang-format on
}
};
struct
FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl
:
public
FlatmmSn_32x128x512_1x4x1_16x16x32_Base
{
using
BDataType
=
bf16_t
;
using
ODataType
=
bf16_t
;
// TODO: need paired with tile_window_linear!
// TODO: need call init_raw() before call this function!
// template <typename AWindow, typename BWindow, typename OWindow, typename ScaleTensor>
template
<
typename
BRes
,
typename
BCoords
,
typename
ORes
,
typename
OCoords
,
typename
OFlags
,
typename
ScaleTensor
>
CK_TILE_DEVICE
auto
operator
()(
const
BRes
&
res_b
,
const
BCoords
&
cached_coords_b
,
const
ORes
&
res_o
,
const
OCoords
&
cached_coords_o
,
const
OFlags
&
o_flags
,
// this should be in sgpr
CK_TILE_LDS_ADDR
void
*
smem
,
index_t
n
,
// loop along n dim
const
ScaleTensor
&
scale_
,
index_t
tile_offset_b
,
// stride b is fixed to blockKr * blockW, but still can adjust
index_t
tile_offset_o
)
{
static_assert
(
BCoords
::
size
()
==
8
);
// 8
static_assert
(
OCoords
::
size
()
==
8
);
const
index_t
tile_stride_b_bytes
=
tile_offset_b
*
sizeof
(
BDataType
);
const
index_t
tile_stride_o_bytes
=
tile_offset_o
*
sizeof
(
ODataType
);
static_assert
(
ScaleTensor
::
size
()
==
2
);
float
s0
=
scale_
[
number
<
0
>
{}];
float
s1
=
scale_
[
number
<
1
>
{}];
// index_t loop_cnt = n / Block_N;
register
float
v_c0
asm
(
"v64"
);
register
float
v_c1
asm
(
"v65"
);
register
float
v_c2
asm
(
"v66"
);
register
float
v_c3
asm
(
"v67"
);
register
float
v_c4
asm
(
"v68"
);
register
float
v_c5
asm
(
"v69"
);
register
float
v_c6
asm
(
"v70"
);
register
float
v_c7
asm
(
"v71"
);
register
float
v_c8
asm
(
"v72"
);
register
float
v_c9
asm
(
"v73"
);
register
float
v_c10
asm
(
"v74"
);
register
float
v_c11
asm
(
"v75"
);
register
float
v_c12
asm
(
"v76"
);
register
float
v_c13
asm
(
"v77"
);
register
float
v_c14
asm
(
"v78"
);
register
float
v_c15
asm
(
"v79"
);
register
float
v_c16
asm
(
"v80"
);
register
float
v_c17
asm
(
"v81"
);
register
float
v_c18
asm
(
"v82"
);
register
float
v_c19
asm
(
"v83"
);
register
float
v_c20
asm
(
"v84"
);
register
float
v_c21
asm
(
"v85"
);
register
float
v_c22
asm
(
"v86"
);
register
float
v_c23
asm
(
"v87"
);
register
float
v_c24
asm
(
"v88"
);
register
float
v_c25
asm
(
"v89"
);
register
float
v_c26
asm
(
"v90"
);
register
float
v_c27
asm
(
"v91"
);
register
float
v_c28
asm
(
"v92"
);
register
float
v_c29
asm
(
"v93"
);
register
float
v_c30
asm
(
"v94"
);
register
float
v_c31
asm
(
"v95"
);
int32_t
nan_hi
=
0x7fff0000
;
int32_t
nan_lo
=
0x00007fff
;
// in smem, the layout is M0(2)*K0(128)*M1(16)*K1(4)
// every threads need 8xK in contiguous register
// ... and every wave need the same data
int
lane_id
=
threadIdx
.
x
%
64
;
int
sld_y_os
=
(
lane_id
%
16
)
*
4
+
(
lane_id
/
16
)
*
128
;
sld_y_os
*=
2
;
// y y p p p y
// reg before shfl M0(2)*N0(2)*Nl(4)*Nw(4)*Mw(16)*Nv(4)
// but order is N0*M0*Nv
// in LDS we need store as
// M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16)*Nv(4) + 4)
// y y wave-id lid/16 lid%16 v
// sst(v3) = (v0/16*34 + v0%16 * 2 + wid*136) * 4
int
sfl_sst
=
(
threadIdx
.
x
%
16
*
4
)
+
(
threadIdx
.
x
/
16
)
*
(
64
+
4
);
sfl_sst
*=
2
;
// from LDS we need load as
// M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16) * Nv(4) + 4)
// ( 2 issue) (rem 32-lane) (4 wave*4issue) 2lane*1ussue(pk2)
// sld(v4) = v0/2 *34*4 + v0 % 2 *4 + wid*2 *4
int
sfl_sld
=
(
lane_id
%
2
)
*
2
+
(
lane_id
/
2
)
*
(
64
+
4
)
+
(
threadIdx
.
x
/
64
)
*
4
;
sfl_sld
*=
2
;
// B nr->kr
// clang-format off
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Winline-asm"
asm
volatile
(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
#include "uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16_itl.inc"
#undef CK_TILE_FLATMM_UK_MFMA
:
[
smem_
]
"+r"
(
smem
),
[
s_loop_cnt
]
"+s"
(
n
),
[
c0
]
"+v"
(
v_c0
),
[
c1
]
"+v"
(
v_c1
),
[
c2
]
"+v"
(
v_c2
),
[
c3
]
"+v"
(
v_c3
),
[
c4
]
"+v"
(
v_c4
),
[
c5
]
"+v"
(
v_c5
),
[
c6
]
"+v"
(
v_c6
),
[
c7
]
"+v"
(
v_c7
),
[
c8
]
"+v"
(
v_c8
),
[
c9
]
"+v"
(
v_c9
),
[
c10
]
"+v"
(
v_c10
),
[
c11
]
"+v"
(
v_c11
),
[
c12
]
"+v"
(
v_c12
),
[
c13
]
"+v"
(
v_c13
),
[
c14
]
"+v"
(
v_c14
),
[
c15
]
"+v"
(
v_c15
),
[
c16
]
"+v"
(
v_c16
),
[
c17
]
"+v"
(
v_c17
),
[
c18
]
"+v"
(
v_c18
),
[
c19
]
"+v"
(
v_c19
),
[
c20
]
"+v"
(
v_c20
),
[
c21
]
"+v"
(
v_c21
),
[
c22
]
"+v"
(
v_c22
),
[
c23
]
"+v"
(
v_c23
),
[
c24
]
"+v"
(
v_c24
),
[
c25
]
"+v"
(
v_c25
),
[
c26
]
"+v"
(
v_c26
),
[
c27
]
"+v"
(
v_c27
),
[
c28
]
"+v"
(
v_c28
),
[
c29
]
"+v"
(
v_c29
),
[
c30
]
"+v"
(
v_c30
),
[
c31
]
"+v"
(
v_c31
)
:
[
sld_a_base
]
"n"
(
0
),
[
shfl_base
]
"n"
(
0
),
[
v_sld_y_os
]
"v"
(
sld_y_os
),
[
v_sfl_sld
]
"v"
(
sfl_sld
),
[
v_sfl_sst
]
"v"
(
sfl_sst
),
[
s_res_o0
]
"s"
(
res_o
[
0
]),
[
s_res_o1
]
"s"
(
res_o
[
1
]),
//[s_res_o2]"s"(res_o[2]),
//[s_res_o3]"s"(res_o[3]),
[
s_res_b0
]
"s"
(
res_b
[
0
]),
[
s_res_b1
]
"s"
(
res_b
[
1
]),
[
s_res_b2
]
"s"
(
res_b
[
2
]),
[
s_res_b3
]
"s"
(
res_b
[
3
]),
[
v_os_o0
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
0
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o1
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
1
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o2
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
2
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o3
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
3
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o4
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
4
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o5
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
5
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o6
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
6
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o7
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
7
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_b0
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
0
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b1
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
1
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b2
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
2
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b3
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
3
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b4
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
4
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b5
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
5
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b6
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
6
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b7
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
7
>
{}]
*
sizeof
(
BDataType
))),
[
s_tile_os_o
]
"s"
(
tile_stride_o_bytes
),
[
s_tile_os_b
]
"s"
(
tile_stride_b_bytes
),
[
scale_0
]
"v"
(
s0
),
[
scale_1
]
"v"
(
s1
),
[
v_nan_lo
]
"v"
(
nan_lo
),
[
v_nan_hi
]
"v"
(
nan_hi
),
[
s_execflag_0
]
"s"
(
o_flags
[
number
<
0
>
{}]),
[
s_execflag_1
]
"s"
(
o_flags
[
number
<
1
>
{}]),
[
s_execflag_2
]
"s"
(
o_flags
[
number
<
2
>
{}]),
[
s_execflag_3
]
"s"
(
o_flags
[
number
<
3
>
{}]),
[
s_execflag_4
]
"s"
(
o_flags
[
number
<
4
>
{}]),
[
s_execflag_5
]
"s"
(
o_flags
[
number
<
5
>
{}]),
[
s_execflag_6
]
"s"
(
o_flags
[
number
<
6
>
{}]),
[
s_execflag_7
]
"s"
(
o_flags
[
number
<
7
>
{}])
:
"memory"
,
"a0"
,
"a1"
,
"a2"
,
"a3"
,
"a4"
,
"a5"
,
"a6"
,
"a7"
,
"a8"
,
"a9"
,
"a10"
,
"a11"
,
"a12"
,
"a13"
,
"a14"
,
"a15"
,
"a16"
,
"a17"
,
"a18"
,
"a19"
,
"a20"
,
"a21"
,
"a22"
,
"a23"
,
"a24"
,
"a25"
,
"a26"
,
"a27"
,
"a28"
,
"a29"
,
"a30"
,
"a31"
,
"a32"
,
"a33"
,
"a34"
,
"a35"
,
"a36"
,
"a37"
,
"a38"
,
"a39"
,
"a40"
,
"a41"
,
"a42"
,
"a43"
,
"a44"
,
"a45"
,
"a46"
,
"a47"
,
"a48"
,
"a49"
,
"a50"
,
"a51"
,
"a52"
,
"a53"
,
"a54"
,
"a55"
,
"a56"
,
"a57"
,
"a58"
,
"a59"
,
"a60"
,
"a61"
,
"a62"
,
"a63"
,
"a64"
,
"a65"
,
"a66"
,
"a67"
,
"a68"
,
"a69"
,
"a70"
,
"a71"
,
"a72"
,
"a73"
,
"a74"
,
"a75"
,
"a76"
,
"a77"
,
"a78"
,
"a79"
,
"a80"
,
"a81"
,
"a82"
,
"a83"
,
"a84"
,
"a85"
,
"a86"
,
"a87"
,
"a88"
,
"a89"
,
"a90"
,
"a91"
,
"a92"
,
"a93"
,
"a94"
,
"a95"
,
"a96"
,
"a97"
,
"a98"
,
"a99"
,
"a100"
,
"a101"
,
"a102"
,
"a103"
,
"a104"
,
"a105"
,
"a106"
,
"a107"
,
"a108"
,
"a109"
,
"a110"
,
"a111"
,
"a112"
,
"a113"
,
"a114"
,
"a115"
,
"a116"
,
"a117"
,
"a118"
,
"a119"
,
"a120"
,
"a121"
,
"a122"
,
"a123"
,
"a124"
,
"a125"
,
"a126"
,
"a127"
,
"a128"
,
"a129"
,
"a130"
,
"a131"
,
"a132"
,
"a133"
,
"a134"
,
"a135"
,
"a136"
,
"a137"
,
"a138"
,
"a139"
,
"a140"
,
"a141"
,
"a142"
,
"a143"
,
"a144"
,
"a145"
,
"a146"
,
"a147"
,
"a148"
,
"a149"
,
"a150"
,
"a151"
,
"a152"
,
"a153"
,
"a154"
,
"a155"
,
"a156"
,
"a157"
,
"a158"
,
"a159"
,
"a160"
,
"a161"
,
"a162"
,
"a163"
,
"a164"
,
"a165"
,
"a166"
,
"a167"
,
"a168"
,
"a169"
,
"a170"
,
"a171"
,
"a172"
,
"a173"
,
"a174"
,
"a175"
,
"a176"
,
"a177"
,
"a178"
,
"a179"
,
"a180"
,
"a181"
,
"a182"
,
"a183"
,
"a184"
,
"a185"
,
"a186"
,
"a187"
,
"a188"
,
"a189"
,
"a190"
,
"a191"
,
"a192"
,
"a193"
,
"a194"
,
"a195"
,
"a196"
,
"a197"
,
"a198"
,
"a199"
,
"a200"
,
"a201"
,
"a202"
,
"a203"
,
"a204"
,
"a205"
,
"a206"
,
"a207"
,
"a208"
,
"a209"
,
"a210"
,
"a211"
,
"a212"
,
"a213"
,
"a214"
,
"a215"
,
"a216"
,
"a217"
,
"a218"
,
"a219"
,
"a220"
,
"a221"
,
"a222"
,
"a223"
,
"a224"
,
"a225"
,
"a226"
,
"a227"
,
"a228"
,
"a229"
,
"a230"
,
"a231"
,
"a232"
,
"a233"
,
"a234"
,
"a235"
,
"a236"
,
"a237"
,
"a238"
,
"a239"
,
"a240"
,
"a241"
,
"a242"
,
"a243"
,
"a244"
,
"a245"
,
"a246"
,
"a247"
,
"a248"
,
"a249"
,
"a250"
,
"a251"
,
"a252"
,
"a253"
,
"a254"
,
"a255"
,
"s8"
,
"s9"
,
"s12"
,
"s13"
,
"s14"
,
"s15"
,
"s38"
,
"s39"
,
"s52"
,
"s86"
,
"s36"
,
"s37"
,
"s56"
,
"s59"
,
"s60"
,
"s80"
,
"v10"
,
"v11"
,
"v12"
,
"v13"
,
"v14"
,
"v15"
,
"v16"
,
"v17"
,
"v50"
,
"v54"
,
"v55"
,
"v64"
,
"v65"
,
"v66"
,
"v67"
,
"v68"
,
"v69"
,
"v70"
,
"v71"
,
"v72"
,
"v73"
,
"v74"
,
"v75"
,
"v76"
,
"v77"
,
"v78"
,
"v79"
,
"v80"
,
"v81"
,
"v82"
,
"v83"
,
"v84"
,
"v85"
,
"v86"
,
"v87"
,
"v88"
,
"v89"
,
"v90"
,
"v91"
,
"v92"
,
"v93"
,
"v94"
,
"v95"
,
"v128"
,
"v129"
,
"v130"
,
"v131"
,
"v132"
,
"v133"
,
"v134"
,
"v135"
,
"v136"
,
"v137"
,
"v138"
,
"v139"
,
"v140"
,
"v141"
,
"v142"
,
"v143"
,
"v144"
,
"v145"
,
"v146"
,
"v147"
,
"v148"
,
"v149"
,
"v150"
,
"v151"
,
"v152"
,
"v153"
,
"v154"
,
"v155"
,
"v156"
,
"v157"
,
"v158"
,
"v159"
,
"v160"
,
"v161"
,
"v162"
,
"v163"
,
"v164"
,
"v165"
,
"v166"
,
"v167"
,
"v168"
,
"v169"
,
"v170"
,
"v171"
,
"v172"
,
"v173"
,
"v174"
,
"v175"
,
"v176"
,
"v177"
,
"v178"
,
"v179"
,
"v180"
,
"v181"
,
"v182"
,
"v183"
,
"v184"
,
"v185"
,
"v186"
,
"v187"
,
"v188"
,
"v189"
,
"v190"
,
"v191"
,
"v192"
,
"v193"
,
"v194"
,
"v195"
,
"v196"
,
"v197"
,
"v198"
,
"v199"
,
"v200"
,
"v201"
,
"v202"
,
"v203"
,
"v204"
,
"v205"
,
"v206"
,
"v207"
,
"v208"
,
"v209"
,
"v210"
,
"v211"
,
"v212"
,
"v213"
,
"v214"
,
"v215"
,
"v216"
,
"v217"
,
"v218"
,
"v219"
,
"v220"
,
"v221"
,
"v222"
,
"v223"
,
"v224"
,
"v225"
,
"v226"
,
"v227"
,
"v228"
,
"v229"
,
"v230"
,
"v231"
,
"v232"
,
"v233"
,
"v234"
,
"v235"
,
"v236"
,
"v237"
,
"v238"
,
"v239"
,
"v240"
,
"v241"
,
"v242"
,
"v243"
,
"v244"
,
"v245"
,
"v246"
,
"v247"
,
"v248"
,
"v249"
,
"v250"
,
"v251"
,
"v252"
,
"v253"
,
"v254"
,
"v255"
);
#pragma clang diagnostic pop
// clang-format on
}
};
}
// namespace ck_tile
Prev
1
…
4
5
6
7
8
9
10
11
12
…
19
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment