gaoqiong / composable_kernel · commit 2837e6b3

Commit 2837e6b3 (unverified), authored Sep 14, 2023 by Chao Liu, committed by GitHub Sep 14, 2023. Parent: 6bc9ee05.

Batch gemm softmax gemm (#11)

* make it simple
* batched gemm+softmax+gemm
Showing 16 changed files with 970 additions and 512 deletions (+970, -512):

  example/91_tile_program/CMakeLists.txt                        +1   -0
  example/91_tile_program/batched_gemm_softmax_gemm.cpp         +166 -0
  example/91_tile_program/batched_gemm_softmax_gemm.hpp         +109 -0
  example/91_tile_program/gemm.cpp                              +23  -18
  example/91_tile_program/gemm_gemm.cpp                         +36  -31
  example/91_tile_program/gemm_gemm.hpp                         +3   -2
  example/91_tile_program/gemm_softmax_gemm.cpp                 +84  -81
  example/91_tile_program/gemm_softmax_gemm.hpp                 +53  -348
  example/91_tile_program/gemm_softmax_gemm_impl.hpp            +358 -0
  example/91_tile_program/reference_batched_gemm.hpp            +36  -0
  example/91_tile_program/reference_batched_softmax.hpp         +46  -0
  example/91_tile_program/reference_gemm.hpp                    +16  -14
  example/91_tile_program/softmax.hpp                           +8   -12
  include/ck/host_utility/kernel_launch.hpp                     +2   -2
  include/ck/tile_program/block_tile/block_reduce.hpp           +3   -3
  include/ck/tile_program/tile/static_distributed_tensor.hpp    +26  -1
example/91_tile_program/CMakeLists.txt (+1, -0):

@@ -4,3 +4,4 @@ add_example_executable(example_gemm_gemm gemm_gemm.cpp)
 add_example_executable(example_reduce reduce.cpp)
 add_example_executable(example_softmax softmax.cpp)
 add_example_executable(example_gemm_softmax_gemm gemm_softmax_gemm.cpp)
+add_example_executable(example_batched_gemm_softmax_gemm batched_gemm_softmax_gemm.cpp)
example/91_tile_program/batched_gemm_softmax_gemm.cpp — new file (0 → 100644, +166):

#include <cstring>

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor/tensor_view.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"

#include "reference_batched_gemm.hpp"
#include "reference_batched_softmax.hpp"
#include "batched_gemm_softmax_gemm.hpp"

int main(int argc, char* argv[])
{
    using QDataType           = ck::half_t;
    using KDataType           = ck::half_t;
    using VDataType           = ck::half_t;
    using SaccDataType        = float;
    using SMPLComputeDataType = float;
    using PDataType           = ck::half_t;
    using OaccDataType        = float;
    using ODataType           = ck::half_t;

    ck::index_t Batch = 16;
    ck::index_t M0    = 4096;
    ck::index_t N0    = 4096;
    ck::index_t K0    = 128;
    ck::index_t N1    = 128;

    if(argc == 6)
    {
        Batch = std::stoi(argv[1]);
        M0    = std::stoi(argv[2]);
        N0    = std::stoi(argv[3]);
        K0    = std::stoi(argv[4]);
        N1    = std::stoi(argv[5]);
    }

    std::array<ck::index_t, 3> q_lengths{Batch, M0, K0};
    std::array<ck::index_t, 3> q_strides{M0 * K0, K0, 1};

    std::array<ck::index_t, 3> k_lengths{Batch, N0, K0};
    std::array<ck::index_t, 3> k_strides{N0 * K0, K0, 1};

    std::array<ck::index_t, 3> v_lengths{Batch, N1, N0};
    std::array<ck::index_t, 3> v_strides{N1 * N0, N0, 1};

    std::array<ck::index_t, 3> s_lengths{Batch, M0, N0};
    std::array<ck::index_t, 3> s_strides{M0 * N0, N0, 1};

    std::array<ck::index_t, 3> p_lengths{Batch, M0, N0};
    std::array<ck::index_t, 3> p_strides{M0 * N0, N0, 1};

    std::array<ck::index_t, 3> o_lengths{Batch, M0, N1};
    std::array<ck::index_t, 3> o_strides{M0 * N1, N1, 1};

    // host verify
    Tensor<QDataType> q_host(q_lengths, q_strides);
    Tensor<KDataType> k_host(k_lengths, k_strides);
    Tensor<VDataType> v_host(v_lengths, v_strides);
    Tensor<SMPLComputeDataType> s_host_ref(s_lengths, s_strides);
    Tensor<PDataType> p_host_ref(p_lengths, p_strides);
    Tensor<ODataType> o_host_ref(o_lengths, o_strides);
    Tensor<ODataType> o_host_dev(o_lengths, o_strides);

#if 0
    ck::utils::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f}(q_host);
    ck::utils::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f}(k_host);
    ck::utils::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f}(v_host);
#else
    ck::utils::FillUniformDistribution<QDataType>{-3.f, 3.f}(q_host);
    ck::utils::FillUniformDistribution<KDataType>{-3.f, 3.f}(k_host);
    ck::utils::FillUniformDistribution<VDataType>{-3.f, 3.f}(v_host);
#endif

    // reference
    reference_batched_gemm<QDataType, KDataType, SaccDataType, SMPLComputeDataType>(
        q_host, k_host, s_host_ref);
    reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(s_host_ref,
                                                                                   p_host_ref);
    reference_batched_gemm<PDataType, VDataType, OaccDataType, ODataType>(
        p_host_ref, v_host, o_host_ref);

    DeviceMem q_buf(sizeof(QDataType) * q_host.GetElementSpaceSize());
    DeviceMem k_buf(sizeof(KDataType) * k_host.GetElementSpaceSize());
    DeviceMem v_buf(sizeof(VDataType) * v_host.GetElementSpaceSize());
    DeviceMem o_buf(sizeof(ODataType) * o_host_ref.GetElementSpaceSize());

    q_buf.ToDevice(q_host.mData.data());
    k_buf.ToDevice(k_host.mData.data());
    v_buf.ToDevice(v_host.mData.data());

    constexpr ck::index_t kM0PerBlock = 128;
    constexpr ck::index_t kN0PerBlock = 128;
    constexpr ck::index_t kK0PerBlock = 32;
    constexpr ck::index_t kN1PerBlock = 128;

    constexpr ck::index_t kBlockSize = 256;
    ck::index_t kGridSize            = Batch * (M0 / kM0PerBlock) * (N1 / kN1PerBlock);

    std::cout << "grid size " << kGridSize << std::endl;

    constexpr ck::index_t kWarpPerCu    = 8; // 2 warps per SIMD
    constexpr ck::index_t kWarpPerBlock = kBlockSize / warpSize;
    constexpr ck::index_t kBlockPerCu   = kWarpPerCu / kWarpPerBlock;

    float ave_time = launch_kernel<kBlockSize, kBlockPerCu>(
        StreamConfig{nullptr, true},
        BatchedGemmSoftmaxGemm<QDataType,
                               KDataType,
                               VDataType,
                               SaccDataType,
                               SMPLComputeDataType,
                               PDataType,
                               OaccDataType,
                               ODataType,
                               kBlockSize,
                               kM0PerBlock,
                               kN0PerBlock,
                               kK0PerBlock,
                               kN1PerBlock>{},
        kGridSize,
        kBlockSize,
        0,
        static_cast<QDataType*>(q_buf.GetDeviceBuffer()),
        static_cast<KDataType*>(k_buf.GetDeviceBuffer()),
        static_cast<VDataType*>(v_buf.GetDeviceBuffer()),
        static_cast<ODataType*>(o_buf.GetDeviceBuffer()),
        M0,
        N0,
        K0,
        N1,
        Batch,
        K0,       // StrideQ
        K0,       // StrideK
        N0,       // StrideV
        N1,       // StrideO
        M0 * K0,  // BatchStrideQ
        N0 * K0,  // BatchStrideK
        N1 * N0,  // BatchStrideV
        M0 * N1); // BatchStrideO

    o_buf.FromDevice(o_host_dev.mData.data());

    std::size_t flop =
        std::size_t(2) * Batch * M0 * N0 * K0 + std::size_t(2) * Batch * M0 * N1 * N0;

    std::size_t num_btype =
        sizeof(QDataType) * Batch * M0 * K0 + sizeof(KDataType) * Batch * N0 * K0 +
        sizeof(VDataType) * Batch * N1 * N0 + sizeof(ODataType) * Batch * M0 * N1;

    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

    float gb_per_sec = num_btype / 1.E6 / ave_time;

    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
              << std::endl;

    return !ck::utils::check_err(o_host_dev, o_host_ref);
}
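For the default problem size, the launch configuration above works out as follows (a worked example added here for clarity; it assumes the 64-lane wavefronts that warpSize denotes on AMD GPUs):

\[
\text{kGridSize} = \text{Batch} \cdot \frac{M_0}{\text{kM0PerBlock}} \cdot \frac{N_1}{\text{kN1PerBlock}}
                 = 16 \cdot \frac{4096}{128} \cdot \frac{128}{128} = 512,
\qquad
\text{kBlockPerCu} = \frac{\text{kWarpPerCu}}{\text{kBlockSize}/\text{warpSize}} = \frac{8}{256/64} = 2.
\]

Each workgroup therefore owns exactly one (batch, M0-tile, N1-tile) triple, and the computed occupancy hint equals the literal 2 that these examples hard-coded into launch_kernel before this commit parameterized it.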
example/91_tile_program/batched_gemm_softmax_gemm.hpp — new file (0 → 100644, +109):

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp"
#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck/tile_program/block_tile/block_reduce.hpp"

#include "gemm_softmax_gemm_impl.hpp"

// S[M0, N0] = Q[M0, K0] * K[N0, K0]
// P[M0, N0] = Softmax(S[M0, N0])
// O[M0, N1] = P[M0, N0] * V[N1, N0]
template <typename QDataType,
          typename KDataType,
          typename VDataType,
          typename SaccDataType,
          typename SMPLComputeDataType,
          typename PDataType,
          typename OaccDataType,
          typename ODataType,
          ck::index_t kBlockSize,
          ck::index_t kM0PerBlock,
          ck::index_t kN0PerBlock,
          ck::index_t kK0PerBlock,
          ck::index_t kN1PerBlock>
struct BatchedGemmSoftmaxGemm
{
    __device__ void operator()(const QDataType* q_ptr,
                               const KDataType* k_ptr,
                               const VDataType* v_ptr,
                               ODataType* o_ptr,
                               const ck::index_t M0,
                               const ck::index_t N0,
                               const ck::index_t K0,
                               const ck::index_t N1,
                               const ck::index_t /* Batch */,
                               const ck::index_t StrideQ,
                               const ck::index_t StrideK,
                               const ck::index_t StrideV,
                               const ck::index_t StrideO,
                               const ck::index_t BatchStrideQ,
                               const ck::index_t BatchStrideK,
                               const ck::index_t BatchStrideV,
                               const ck::index_t BatchStrideO) const
    {
        using namespace ck;

        // divide problem
        const index_t num_tile_m0 = M0 / kM0PerBlock;
        const index_t num_tile_n1 = N1 / kN1PerBlock;

        const index_t id_block = get_block_id();

        const auto f = [](index_t dividend, index_t divisor) {
            index_t quotient = dividend / divisor;
            index_t modulus  = dividend - quotient * divisor;
            return ck::make_tuple(quotient, modulus);
        };

        const auto [itmp, id_tile_n]          = f(id_block, num_tile_n1);
        const auto [id_tile_batch, id_tile_m] = f(itmp, num_tile_m0);

        const index_t iBatch = __builtin_amdgcn_readfirstlane(id_tile_batch);
        const index_t iM0    = __builtin_amdgcn_readfirstlane(id_tile_m * kM0PerBlock);
        const index_t iN1    = __builtin_amdgcn_readfirstlane(id_tile_n * kN1PerBlock);

        const auto kernel_impl = GemmSoftmaxGemmImpl<QDataType,
                                                     KDataType,
                                                     VDataType,
                                                     SaccDataType,
                                                     SMPLComputeDataType,
                                                     PDataType,
                                                     OaccDataType,
                                                     ODataType,
                                                     kBlockSize,
                                                     kM0PerBlock,
                                                     kN0PerBlock,
                                                     kK0PerBlock,
                                                     kN1PerBlock>{};

        kernel_impl(q_ptr + iBatch * BatchStrideQ,
                    k_ptr + iBatch * BatchStrideK,
                    v_ptr + iBatch * BatchStrideV,
                    o_ptr + iBatch * BatchStrideO,
                    M0,
                    N0,
                    K0,
                    N1,
                    StrideQ,
                    StrideK,
                    StrideV,
                    StrideO,
                    iM0,
                    iN1);
    }
};
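The lambda f above is a plain divmod: the flattened block id is decoded with the N1-tile index varying fastest, then the M0-tile index, then the batch index. A minimal host-side sketch of the same decomposition (hypothetical standalone code, not part of this commit):

#include <cassert>
#include <utility>

// Mirror of the in-kernel divmod lambda. Encoding assumed:
// id = (batch * num_tile_m + tile_m) * num_tile_n + tile_n.
std::pair<int, int> divmod(int dividend, int divisor)
{
    const int quotient = dividend / divisor;
    return {quotient, dividend - quotient * divisor};
}

int main()
{
    const int batch = 16, num_tile_m = 32, num_tile_n = 1; // e.g. M0/kM0PerBlock, N1/kN1PerBlock
    for(int id = 0; id < batch * num_tile_m * num_tile_n; ++id)
    {
        const auto [itmp, tile_n] = divmod(id, num_tile_n);
        const auto [b, tile_m]    = divmod(itmp, num_tile_m);
        assert(id == (b * num_tile_m + tile_m) * num_tile_n + tile_n); // decoding round-trips
    }
    return 0;
}

The decoded indices are passed through __builtin_amdgcn_readfirstlane because every lane of a wavefront computes the same values; broadcasting them from lane 0 lets the compiler keep them in scalar registers.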
example/91_tile_program/gemm.cpp (+23, -18):

@@ -81,7 +81,7 @@ int main(int argc, char* argv[])
     ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_host);

     // reference gemm
-    reference_gemm<ADataType, ADataType, CDataType, float>(a_host, b_host, c_host_ref);
+    reference_gemm<ADataType, ADataType, AccDataType, CDataType>(a_host, b_host, c_host_ref);

     DeviceMem a_buf(sizeof(ADataType) * a_host.GetElementSpaceSize());
     DeviceMem b_buf(sizeof(BDataType) * b_host.GetElementSpaceSize());
@@ -99,6 +99,10 @@ int main(int argc, char* argv[])
     std::cout << "grid size " << kGridSize << std::endl;

+    constexpr ck::index_t kWarpPerCu    = 8; // 2 warps per SIMD
+    constexpr ck::index_t kWarpPerBlock = kBlockSize / warpSize;
+    constexpr ck::index_t kBlockPerCu   = kWarpPerCu / kWarpPerBlock;
+
     const auto gemm_kernel = Gemm<ADataType,
                                   BDataType,
                                   AccDataType,
@@ -114,23 +118,24 @@ int main(int argc, char* argv[])
                                   kGemmNPerBlock,
                                   kGemmKPerBlock>{};

-    float ave_time = launch_kernel<kBlockSize, 2>(StreamConfig{nullptr, true},
-                                                  gemm_kernel,
-                                                  kGridSize,
-                                                  kBlockSize,
-                                                  0,
-                                                  ...
+    float ave_time =
+        launch_kernel<kBlockSize, kBlockPerCu>(StreamConfig{nullptr, true},
+                                               gemm_kernel,
+                                               kGridSize,
+                                               kBlockSize,
+                                               0,
+                                               static_cast<ADataType*>(a_buf.GetDeviceBuffer()),
+                                               static_cast<BDataType*>(b_buf.GetDeviceBuffer()),
+                                               static_cast<CDataType*>(c_buf.GetDeviceBuffer()),
+                                               M,
+                                               N,
+                                               K,
+                                               K,
+                                               K,
+                                               N,
+                                               AElementFunction{},
+                                               BElementFunction{},
+                                               CElementFunction{});

     c_buf.FromDevice(c_host_dev.mData.data());
example/91_tile_program/gemm_gemm.cpp (+36, -31):

@@ -20,9 +20,9 @@ int main(int argc, char* argv[])
 {
     using A0DataType   = ck::half_t;
     using B0DataType   = ck::half_t;
+    using B1DataType   = ck::half_t;
     using Acc0DataType = float;
     using C0DataType   = ck::half_t;
-    using B1DataType   = ck::half_t;
     using Acc1DataType = float;
     using C1DataType   = ck::half_t;
@@ -67,8 +67,9 @@ int main(int argc, char* argv[])
     ck::utils::FillUniformDistributionIntegerValue<B1DataType>{-3.f, 3.f}(b1_host);

     // reference gemm
-    reference_gemm<A0DataType, B0DataType, C0DataType, float>(a0_host, b0_host, c0_host_ref);
-    reference_gemm<C0DataType, B1DataType, C1DataType, float>(c0_host_ref, b1_host, c1_host_ref);
+    reference_gemm<A0DataType, B0DataType, Acc0DataType, C0DataType>(a0_host, b0_host, c0_host_ref);
+    reference_gemm<C0DataType, B1DataType, Acc1DataType, C1DataType>(
+        c0_host_ref, b1_host, c1_host_ref);

     DeviceMem a0_buf(sizeof(A0DataType) * a0_host.GetElementSpaceSize());
     DeviceMem b0_buf(sizeof(B0DataType) * b0_host.GetElementSpaceSize());
@@ -89,35 +90,39 @@ int main(int argc, char* argv[])
     std::cout << "grid size " << kGridSize << std::endl;

+    constexpr ck::index_t kWarpPerCu    = 8; // 2 warps per SIMD
+    constexpr ck::index_t kWarpPerBlock = kBlockSize / warpSize;
+    constexpr ck::index_t kBlockPerCu   = kWarpPerCu / kWarpPerBlock;
+
     float ave_time =
-        launch_kernel<kBlockSize, 2>(StreamConfig{nullptr, true},
-                                     GemmGemm<A0DataType,
-                                              B0DataType,
-                                              Acc0DataType,
-                                              C0DataType,
-                                              B1DataType,
-                                              ...
+        launch_kernel<kBlockSize, kBlockPerCu>(StreamConfig{nullptr, true},
+                                               GemmGemm<A0DataType,
+                                                        B0DataType,
+                                                        B1DataType,
+                                                        Acc0DataType,
+                                                        C0DataType,
+                                                        Acc1DataType,
+                                                        C1DataType,
+                                                        kBlockSize,
+                                                        kM0PerBlock,
+                                                        kN0PerBlock,
+                                                        kK0PerBlock,
+                                                        kN1PerBlock>{},
+                                               kGridSize,
+                                               kBlockSize,
+                                               0,
+                                               static_cast<A0DataType*>(a0_buf.GetDeviceBuffer()),
+                                               static_cast<B0DataType*>(b0_buf.GetDeviceBuffer()),
+                                               static_cast<B1DataType*>(b1_buf.GetDeviceBuffer()),
+                                               static_cast<C1DataType*>(c1_buf.GetDeviceBuffer()),
+                                               M0,
+                                               N0,
+                                               K0,
+                                               N1,
+                                               K0,  // Lda0
+                                               K0,  // Ldb0
+                                               N0,  // Ldb1
+                                               N1); // Ldc1

     c1_buf.FromDevice(c1_host_dev.mData.data());
example/91_tile_program/gemm_gemm.hpp (+3, -2):

@@ -16,12 +16,13 @@
 #include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp"
 #include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp"

-// C1 = A0 * B0 * B1
+// C0 = A0 * B0
+// C1 = C0 * B1
 template <typename A0DataType,
           typename B0DataType,
+          typename B1DataType,
           typename Acc0DataType,
           typename C0DataType,
-          typename B1DataType,
           typename Acc1DataType,
           typename C1DataType,
           ck::index_t kBlockSize,
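Written out with the row-major [rows, K] operand layouts these examples assume (each GEMM reads its B operand as if transposed), the clarified comment corresponds to:

\[
C_0[m, n_0] = \sum_{k} A_0[m, k] \, B_0[n_0, k],
\qquad
C_1[m, n_1] = \sum_{n_0} C_0[m, n_0] \, B_1[n_1, n_0].
\]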
example/91_tile_program/gemm_softmax_gemm.cpp (+84, -81):

@@ -19,14 +19,14 @@
 int main(int argc, char* argv[])
 {
-    using A0DataType   = ck::half_t;
-    using B0DataType   = ck::half_t;
-    using Acc0DataType = float;
-    using C0DataType   = ck::half_t;
-    using D0DataType   = ck::half_t;
-    using B1DataType   = ck::half_t;
-    using Acc1DataType = float;
-    using C1DataType   = ck::half_t;
+    using QDataType           = ck::half_t;
+    using KDataType           = ck::half_t;
+    using VDataType           = ck::half_t;
+    using SaccDataType        = float;
+    using SMPLComputeDataType = float;
+    using PDataType           = ck::half_t;
+    using OaccDataType        = float;
+    using ODataType           = ck::half_t;

     ck::index_t M0 = 13312;
     ck::index_t N0 = 4096;
@@ -41,56 +41,57 @@ int main(int argc, char* argv[])
         N1 = std::stoi(argv[4]);
     }

-    std::array<ck::index_t, 2> a0_lengths{M0, K0};
-    std::array<ck::index_t, 2> a0_strides{K0, 1};
-    std::array<ck::index_t, 2> b0_lengths{N0, K0};
-    std::array<ck::index_t, 2> b0_strides{K0, 1};
-    std::array<ck::index_t, 2> c0_lengths{M0, N0};
-    std::array<ck::index_t, 2> c0_strides{N0, 1};
-    std::array<ck::index_t, 2> d0_lengths{M0, N0};
-    std::array<ck::index_t, 2> d0_strides{N0, 1};
-    std::array<ck::index_t, 2> b1_lengths{N1, N0};
-    std::array<ck::index_t, 2> b1_strides{N0, 1};
-    std::array<ck::index_t, 2> c1_lengths{M0, N1};
-    std::array<ck::index_t, 2> c1_strides{N1, 1};
+    std::array<ck::index_t, 2> q_lengths{M0, K0};
+    std::array<ck::index_t, 2> q_strides{K0, 1};
+    std::array<ck::index_t, 2> k_lengths{N0, K0};
+    std::array<ck::index_t, 2> k_strides{K0, 1};
+    std::array<ck::index_t, 2> v_lengths{N1, N0};
+    std::array<ck::index_t, 2> v_strides{N0, 1};
+    std::array<ck::index_t, 2> s_lengths{M0, N0};
+    std::array<ck::index_t, 2> s_strides{N0, 1};
+    std::array<ck::index_t, 2> p_lengths{M0, N0};
+    std::array<ck::index_t, 2> p_strides{N0, 1};
+    std::array<ck::index_t, 2> o_lengths{M0, N1};
+    std::array<ck::index_t, 2> o_strides{N1, 1};

     // host verify
-    Tensor<A0DataType> a0_host(a0_lengths, a0_strides);
-    Tensor<B0DataType> b0_host(b0_lengths, b0_strides);
-    Tensor<C0DataType> c0_host_ref(c0_lengths, c0_strides);
-    Tensor<D0DataType> d0_host_ref(d0_lengths, d0_strides);
-    Tensor<B1DataType> b1_host(b1_lengths, b1_strides);
-    Tensor<C1DataType> c1_host_ref(c1_lengths, c1_strides);
-    Tensor<C1DataType> c1_host_dev(c1_lengths, c1_strides);
+    Tensor<QDataType> q_host(q_lengths, q_strides);
+    Tensor<KDataType> k_host(k_lengths, k_strides);
+    Tensor<VDataType> v_host(v_lengths, v_strides);
+    Tensor<SMPLComputeDataType> s_host_ref(s_lengths, s_strides);
+    Tensor<PDataType> p_host_ref(p_lengths, p_strides);
+    Tensor<ODataType> o_host_ref(o_lengths, o_strides);
+    Tensor<ODataType> o_host_dev(o_lengths, o_strides);

-#if 1
-    ck::utils::FillUniformDistributionIntegerValue<A0DataType>{-3.f, 3.f}(a0_host);
-    ck::utils::FillUniformDistributionIntegerValue<B0DataType>{-3.f, 3.f}(b0_host);
-    ck::utils::FillUniformDistributionIntegerValue<B1DataType>{-3.f, 3.f}(b1_host);
+#if 0
+    ck::utils::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f}(q_host);
+    ck::utils::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f}(k_host);
+    ck::utils::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f}(v_host);
 #else
-    ck::utils::FillUniformDistribution<A0DataType>{-3.f, 3.f}(a0_host);
-    ck::utils::FillUniformDistribution<B0DataType>{-3.f, 3.f}(b0_host);
-    ck::utils::FillUniformDistribution<B1DataType>{-3.f, 3.f}(b1_host);
+    ck::utils::FillUniformDistribution<QDataType>{-3.f, 3.f}(q_host);
+    ck::utils::FillUniformDistribution<KDataType>{-3.f, 3.f}(k_host);
+    ck::utils::FillUniformDistribution<VDataType>{-3.f, 3.f}(v_host);
 #endif

     // reference
-    reference_gemm<A0DataType, B0DataType, C0DataType, float>(a0_host, b0_host, c0_host_ref);
-    reference_softmax<C0DataType, float, D0DataType>(c0_host_ref, d0_host_ref);
-    reference_gemm<D0DataType, B1DataType, C1DataType, float>(d0_host_ref, b1_host, c1_host_ref);
+    reference_gemm<QDataType, KDataType, SaccDataType, SMPLComputeDataType>(
+        q_host, k_host, s_host_ref);
+    reference_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(s_host_ref, p_host_ref);
+    reference_gemm<PDataType, VDataType, OaccDataType, ODataType>(p_host_ref, v_host, o_host_ref);

-    DeviceMem a0_buf(sizeof(A0DataType) * a0_host.GetElementSpaceSize());
-    DeviceMem b0_buf(sizeof(B0DataType) * b0_host.GetElementSpaceSize());
-    DeviceMem b1_buf(sizeof(B1DataType) * b1_host.GetElementSpaceSize());
-    DeviceMem c1_buf(sizeof(C1DataType) * c1_host_ref.GetElementSpaceSize());
+    DeviceMem q_buf(sizeof(QDataType) * q_host.GetElementSpaceSize());
+    DeviceMem k_buf(sizeof(KDataType) * k_host.GetElementSpaceSize());
+    DeviceMem v_buf(sizeof(VDataType) * v_host.GetElementSpaceSize());
+    DeviceMem o_buf(sizeof(ODataType) * o_host_ref.GetElementSpaceSize());

-    a0_buf.ToDevice(a0_host.mData.data());
-    b0_buf.ToDevice(b0_host.mData.data());
-    b1_buf.ToDevice(b1_host.mData.data());
+    q_buf.ToDevice(q_host.mData.data());
+    k_buf.ToDevice(k_host.mData.data());
+    v_buf.ToDevice(v_host.mData.data());

     constexpr ck::index_t kM0PerBlock = 128;
     constexpr ck::index_t kN0PerBlock = 128;
@@ -102,41 +103,46 @@ int main(int argc, char* argv[])
     std::cout << "grid size " << kGridSize << std::endl;

+    constexpr ck::index_t kWarpPerCu    = 8; // 2 warps per SIMD
+    constexpr ck::index_t kWarpPerBlock = kBlockSize / warpSize;
+    constexpr ck::index_t kBlockPerCu   = kWarpPerCu / kWarpPerBlock;
+
     float ave_time =
-        launch_kernel<kBlockSize, 2>(StreamConfig{nullptr, true},
-                                     GemmSoftmaxGemm<A0DataType,
-                                                     B0DataType,
-                                                     Acc0DataType,
-                                                     C0DataType,
-                                                     B1DataType,
-                                                     Acc1DataType,
-                                                     C1DataType,
-                                                     kBlockSize,
-                                                     kM0PerBlock,
-                                                     kN0PerBlock,
-                                                     kK0PerBlock,
-                                                     kN1PerBlock>{},
-                                     kGridSize,
-                                     kBlockSize,
-                                     0,
-                                     static_cast<A0DataType*>(a0_buf.GetDeviceBuffer()),
-                                     static_cast<B0DataType*>(b0_buf.GetDeviceBuffer()),
-                                     static_cast<B1DataType*>(b1_buf.GetDeviceBuffer()),
-                                     static_cast<C1DataType*>(c1_buf.GetDeviceBuffer()),
-                                     M0,
-                                     N0,
-                                     K0,
-                                     N1,
-                                     K0,  // Lda0
-                                     K0,  // Ldb0
-                                     N0,  // Ldb1
-                                     N1); // Ldc1
+        launch_kernel<kBlockSize, kBlockPerCu>(StreamConfig{nullptr, true},
+                                               GemmSoftmaxGemm<QDataType,
+                                                               KDataType,
+                                                               VDataType,
+                                                               SaccDataType,
+                                                               SMPLComputeDataType,
+                                                               PDataType,
+                                                               OaccDataType,
+                                                               ODataType,
+                                                               kBlockSize,
+                                                               kM0PerBlock,
+                                                               kN0PerBlock,
+                                                               kK0PerBlock,
+                                                               kN1PerBlock>{},
+                                               kGridSize,
+                                               kBlockSize,
+                                               0,
+                                               static_cast<QDataType*>(q_buf.GetDeviceBuffer()),
+                                               static_cast<KDataType*>(k_buf.GetDeviceBuffer()),
+                                               static_cast<VDataType*>(v_buf.GetDeviceBuffer()),
+                                               static_cast<ODataType*>(o_buf.GetDeviceBuffer()),
+                                               M0,
+                                               N0,
+                                               K0,
+                                               N1,
+                                               K0,  // StrideQ
+                                               K0,  // StrideK
+                                               N0,  // StrideV
+                                               N1); // StrideO

-    c1_buf.FromDevice(c1_host_dev.mData.data());
+    o_buf.FromDevice(o_host_dev.mData.data());

     std::size_t flop = std::size_t(2) * M0 * N0 * K0 + std::size_t(2) * M0 * N1 * N0;

-    std::size_t num_btype = sizeof(A0DataType) * M0 * K0 + sizeof(B0DataType) * N0 * K0 +
-                            sizeof(B1DataType) * N1 * N0 + sizeof(C1DataType) * M0 * N1;
+    std::size_t num_btype = sizeof(QDataType) * M0 * K0 + sizeof(KDataType) * N0 * K0 +
+                            sizeof(VDataType) * N1 * N0 + sizeof(ODataType) * M0 * N1;

     float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
@@ -145,8 +151,5 @@ int main(int argc, char* argv[])
     std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
               << std::endl;

-    // LogRangeAsType<float>(std::cout << "C1 dev: ", c1_host_dev.mData, ", ", 16, 20) << std::endl;
-    // LogRangeAsType<float>(std::cout << "C1 ref: ", c1_host_ref.mData, ", ", 16, 20) << std::endl;
-
-    return !ck::utils::check_err(c1_host_dev, c1_host_ref);
+    return !ck::utils::check_err(o_host_dev, o_host_ref);
 }
example/91_tile_program/gemm_softmax_gemm.hpp (+53, -348):

@@ -17,15 +17,19 @@
 #include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp"
 #include "ck/tile_program/block_tile/block_reduce.hpp"

+#include "gemm_softmax_gemm_impl.hpp"
+
-// C0 = A0 * B0
-// C1 = softmax(C0) * B1
-template <typename A0DataType,
-          typename B0DataType,
-          typename Acc0DataType,
-          typename C0DataType,
-          typename B1DataType,
-          typename Acc1DataType,
-          typename C1DataType,
+// S[M0, N0] = Q[M0, K0] * K[N0, K0]
+// P[M0, N0] = Softmax(S[M0, N0])
+// O[M0, N1] = P[M0, N0] * V[N1, N0]
+template <typename QDataType,
+          typename KDataType,
+          typename VDataType,
+          typename SaccDataType,
+          typename SMPLComputeDataType,
+          typename PDataType,
+          typename OaccDataType,
+          typename ODataType,
           ck::index_t kBlockSize,
           ck::index_t kM0PerBlock,
           ck::index_t kN0PerBlock,
@@ -33,138 +37,21 @@
           ck::index_t kN1PerBlock>
 struct GemmSoftmaxGemm
 {
-    // block gemm0 pipeline
-    using BlockGemm0Pipeline = ck::tile_program::block::BlockGemmPipelineAGmemBGmemCRegV2<
-        ck::tile_program::block::BlockGemmPipelineProblem<
-            A0DataType,
-            B0DataType,
-            Acc0DataType,
-            kBlockSize,
-            ck::tile_program::TileGemmShape<kM0PerBlock, kN0PerBlock, kK0PerBlock>>,
-        ck::tile_program::block::BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy>;
-
-    // block gemm1
-    using BlockGemm1 = ck::tile_program::block::BlockGemmARegBSmemCRegV1<
-        ck::tile_program::block::BlockGemmARegBSmemCRegV1Problem<
-            C0DataType,
-            B1DataType,
-            Acc1DataType,
-            kBlockSize,
-            ck::tile_program::TileGemmShape<kM0PerBlock, kN1PerBlock, kN0PerBlock>>,
-        ck::tile_program::block::BlockGemmARegBSmemCRegV1DefaultPolicy>;
-
-    __device__ static constexpr auto MakeB1LdsBlockDescriptor() { ... }
-    __device__ static constexpr auto MakeB1DramTileDistribution() { ... }
-    __device__ static constexpr ck::index_t GetStaticLdsSize() { ... }
-
-    __device__ void operator()(const A0DataType* p_a0,
-                               const B0DataType* p_b0,
-                               const B1DataType* p_b1,
-                               C1DataType* p_c1,
-                               ck::index_t M0,
-                               ck::index_t N0,
-                               ck::index_t K0,
-                               ck::index_t N1,
-                               ck::index_t Lda0,
-                               ck::index_t Ldb0,
-                               ck::index_t Ldb1,
-                               ck::index_t Ldc1)
-    {
-        ...
-        __shared__ char p_smem_char[GetStaticLdsSize()];
-        ...
-        do
-        {
-            ...
-#if 0 // debug
-            // this use the same equation from FA v2 paper, but produce -nan
-            const auto o_v = o_old_v * tmp2;
-#elif 1
-            // this use different equation from FA v2 paper, but produce correct result
-            (void)tmp2;
-            const auto o_v = o_old_v * tmp;
-#endif
-            ...
-        } while(iN0 < N0);
-        ...
-        store_tile(c1_dram_window, c1_block_tile);
-    }
+    __device__ void operator()(const QDataType* q_ptr,
+                               const KDataType* k_ptr,
+                               const VDataType* v_ptr,
+                               ODataType* o_ptr,
+                               const ck::index_t M0,
+                               const ck::index_t N0,
+                               const ck::index_t K0,
+                               const ck::index_t N1,
+                               const ck::index_t StrideQ,
+                               const ck::index_t StrideK,
+                               const ck::index_t StrideV,
+                               const ck::index_t StrideO) const
+    {
+        using namespace ck;
+
+        // divide problem
+        const auto num_tile_n1 = N1 / kN1PerBlock;
@@ -176,215 +63,33 @@ struct GemmSoftmaxGemm
         const auto iM0 = __builtin_amdgcn_readfirstlane(id_tile_m * kM0PerBlock);
         const auto iN1 = __builtin_amdgcn_readfirstlane(id_tile_n * kN1PerBlock);

+        const auto kernel_impl = GemmSoftmaxGemmImpl<QDataType,
+                                                     KDataType,
+                                                     VDataType,
+                                                     SaccDataType,
+                                                     SMPLComputeDataType,
+                                                     PDataType,
+                                                     OaccDataType,
+                                                     ODataType,
+                                                     kBlockSize,
+                                                     kM0PerBlock,
+                                                     kN0PerBlock,
+                                                     kK0PerBlock,
+                                                     kN1PerBlock>{};
+
+        kernel_impl(q_ptr,
+                    k_ptr,
+                    v_ptr,
+                    o_ptr,
+                    M0,
+                    N0,
+                    K0,
+                    N1,
+                    StrideQ,
+                    StrideK,
+                    StrideV,
+                    StrideO,
+                    iM0,
+                    iN1);
     }
 };
example/91_tile_program/gemm_softmax_gemm_impl.hpp — new file (0 → 100644, +358):

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp"
#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck/tile_program/block_tile/block_reduce.hpp"

// S[M0, N0] = Q[M0, K0] * K[N0, K0]
// P[M0, N0] = Softmax(S[M0, N0])
// O[M0, N1] = P[M0, N0] * V[N1, N0]
template <typename QDataType,
          typename KDataType,
          typename VDataType,
          typename SaccDataType,
          typename SMPLComputeDataType,
          typename PDataType,
          typename OaccDataType,
          typename ODataType,
          ck::index_t kBlockSize,
          ck::index_t kM0PerBlock,
          ck::index_t kN0PerBlock,
          ck::index_t kK0PerBlock,
          ck::index_t kN1PerBlock>
struct GemmSoftmaxGemmImpl
{
    // block gemm0 pipeline
    using BlockGemm0Pipeline = ck::tile_program::block::BlockGemmPipelineAGmemBGmemCRegV2<
        ck::tile_program::block::BlockGemmPipelineProblem<
            QDataType,
            KDataType,
            SaccDataType,
            kBlockSize,
            ck::tile_program::TileGemmShape<kM0PerBlock, kN0PerBlock, kK0PerBlock>>,
        ck::tile_program::block::BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy>;

    // block gemm1
    using BlockGemm1 = ck::tile_program::block::BlockGemmARegBSmemCRegV1<
        ck::tile_program::block::BlockGemmARegBSmemCRegV1Problem<
            PDataType,
            VDataType,
            OaccDataType,
            kBlockSize,
            ck::tile_program::TileGemmShape<kM0PerBlock, kN1PerBlock, kN0PerBlock>>,
        ck::tile_program::block::BlockGemmARegBSmemCRegV1DefaultPolicy>;

#if 0
    // 2d
    __device__ static constexpr auto MakeVLdsBlockDescriptor()
    {
        using namespace ck;

        constexpr index_t kNPerBlock = kN1PerBlock;
        constexpr index_t kKPerBlock = kN0PerBlock;

        constexpr auto b_lds_desc =
            make_naive_tensor_descriptor_packed(make_tuple(kNPerBlock, kKPerBlock), Number<32>{});

        return b_lds_desc;
    }
#else
    // fake XOR
    __device__ static constexpr auto MakeVLdsBlockDescriptor()
    {
        using namespace ck;

        using BDataType = VDataType;

        constexpr index_t kNPerBlock = kN1PerBlock;
        constexpr index_t kKPerBlock = kN0PerBlock;

        constexpr auto b_lds_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
            make_tuple(kNPerBlock / 2, 2, kKPerBlock), Number<kKPerBlock>{});

        constexpr index_t kK1 = 16 / sizeof(BDataType);

        constexpr auto b_lds_desc_d4_d5_d6 = transform_tensor_descriptor(
            b_lds_desc_d1_d2_d3,
            make_tuple(make_xor_transform(make_tuple(kNPerBlock / 2, kKPerBlock), kK1),
                       make_pass_through_transform(2)),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

        constexpr auto b_lds_desc_n_k = transform_tensor_descriptor(
            b_lds_desc_d4_d5_d6,
            make_tuple(make_merge_transform(make_tuple(kNPerBlock / 2, 2)),
                       make_pass_through_transform(kKPerBlock)),
            make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        return b_lds_desc_n_k;
    }
#endif

    __device__ static constexpr auto MakeVDramTileDistribution()
    {
        using namespace ck;
        using namespace ck::tile_program;

        using BDataType = VDataType;

        constexpr index_t kNPerBlock = kN1PerBlock;
        constexpr index_t kKPerBlock = kN0PerBlock;

        constexpr index_t K1 = 16 / sizeof(BDataType);
        constexpr index_t K0 = kKPerBlock / K1;
        constexpr index_t N2 = get_warp_size() / K0;
        constexpr index_t N1 = kBlockSize / get_warp_size();
        constexpr index_t N0 = kNPerBlock / (N2 * N1);

        return make_static_tile_distribution(
            StaticTileDistributionEncoding<Sequence<1>,
                                           Tuple<Sequence<N0, N1, N2>, Sequence<K0, K1>>,
                                           Tuple<Sequence<1>, Sequence<1, 2>>,
                                           Tuple<Sequence<1>, Sequence<2, 0>>,
                                           Sequence<1, 2>,
                                           Sequence<0, 1>>{});
    }

    __device__ static constexpr ck::index_t GetStaticLdsSize()
    {
        using namespace ck;

        return math::max(BlockGemm0Pipeline::GetStaticLdsSize(),
                         static_cast<index_t>(MakeVLdsBlockDescriptor().GetElementSpaceSize() *
                                              sizeof(VDataType)));
    }

    __device__ void operator()(const QDataType* q_ptr,
                               const KDataType* k_ptr,
                               const VDataType* v_ptr,
                               ODataType* o_ptr,
                               const ck::index_t M0,
                               const ck::index_t N0,
                               const ck::index_t K0,
                               const ck::index_t N1,
                               const ck::index_t StrideQ,
                               const ck::index_t StrideK,
                               const ck::index_t StrideV,
                               const ck::index_t StrideO,
                               const ck::index_t iM0,
                               const ck::index_t iN1) const
    {
        using namespace ck;
        using namespace ck::tile_program;
        using namespace ck::tile_program::block;

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        // allocate LDS
        __shared__ char smem_ptr[GetStaticLdsSize()];

        // Q/K/V DRAM and DRAM window
        // FIXME: assume layout Q[M0, K0], K[N0, K0], V[N1, N0], O[M0, N1]
        const auto q_dram = make_naive_tensor_view<AddressSpaceEnum::Global>(
            q_ptr, make_tuple(M0, K0), make_tuple(StrideQ, 1), Number<32>{}, Number<1>{});

        const auto k_dram = make_naive_tensor_view<AddressSpaceEnum::Global>(
            k_ptr, make_tuple(N0, K0), make_tuple(StrideK, 1), Number<32>{}, Number<1>{});

        const auto v_dram = make_naive_tensor_view<AddressSpaceEnum::Global>(
            v_ptr, make_tuple(N1, N0), make_tuple(StrideV, 1), Number<32>{}, Number<1>{});

        auto q_dram_window = make_tile_window(
            q_dram, make_tuple(Number<kM0PerBlock>{}, Number<kK0PerBlock>{}), {iM0, 0});

        auto k_dram_window = make_tile_window(
            k_dram, make_tuple(Number<kN0PerBlock>{}, Number<kK0PerBlock>{}), {0, 0});

        auto v_dram_window =
            make_tile_window(v_dram,
                             make_tuple(Number<kN1PerBlock>{}, Number<kN0PerBlock>{}),
                             {iN1, 0},
                             MakeVDramTileDistribution());

        // V LDS and LDS window
        // V LDS occupies the same LDS allocation as the Q/K LDS
        auto v_lds = make_tensor_view<AddressSpaceEnum::Lds>(
            reinterpret_cast<VDataType*>(smem_ptr), MakeVLdsBlockDescriptor());

        auto v_lds_window = make_tile_window(
            v_lds, make_tuple(Number<kN1PerBlock>{}, Number<kN0PerBlock>{}), {0, 0});

        // Block GEMM0 pipeline and Block GEMM1
        constexpr auto gemm0_pipeline = BlockGemm0Pipeline{};
        constexpr auto gemm1          = BlockGemm1{};

        // reduction functions for softmax
        const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
        const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };

        // infer Sacc, S, P, M, L, Oacc type
        using SaccBlockTileType = decltype(gemm0_pipeline(q_dram_window, k_dram_window, 0, nullptr));

        using SBlockTileType = decltype(tile_elementwise_in(
            type_convert<SMPLComputeDataType, SaccDataType>, SaccBlockTileType{}));

        using PBlockTileType = decltype(tile_elementwise_in(
            type_convert<PDataType, SaccDataType>, SaccBlockTileType{}));

        using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
            SBlockTileType{}, Sequence<1>{}, f_max, SMPLComputeDataType{0}));

        using OaccBlockTileType = decltype(gemm1(PBlockTileType{}, v_dram_window));

        // init Oacc, M, L
        auto o_acc = OaccBlockTileType{};
        auto m     = MLBlockTileType{};
        auto l     = MLBlockTileType{};

        tile_elementwise_inout([](auto& e) { e = 0; }, o_acc);
        tile_elementwise_inout(
            [](auto& e) { e = NumericLimits<SMPLComputeDataType>::Lowest(); }, m);
        tile_elementwise_inout([](auto& e) { e = 0; }, l);

        // loop over the columns of S (the J loop)
        index_t iN0 = 0;

        do
        {
            // Sacc{j} = Q * K{j}
            const auto s_acc =
                gemm0_pipeline(q_dram_window, k_dram_window, K0 / kK0PerBlock, smem_ptr);

            // S{j}
            const auto s =
                tile_elementwise_in(type_convert<SMPLComputeDataType, SaccDataType>, s_acc);

            // m_local = rowmax(S{j})
            auto m_local = block_tile_reduce<SMPLComputeDataType>(
                s, Sequence<1>{}, f_max, NumericLimits<SMPLComputeDataType>::Lowest());

            block_tile_reduce_sync(m_local, f_max);

            // m{j-1}
            const auto m_old = m;

            // m{j}
            tile_elementwise_inout(
                [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local);

            // Pcompute{j}
            auto p_compute =
                make_static_distributed_tensor<SMPLComputeDataType>(s.GetTileDistribution());

            constexpr auto p_spans = decltype(p_compute)::GetDistributedSpans();

            sweep_tile_span(p_spans[I0], [&](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);
                sweep_tile_span(p_spans[I1], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
                    p_compute(i_j_idx)     = math::exp(s[i_j_idx] - m[i_idx]);
                });
            });

            // rowsum(Pcompute{j})
            auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
                p_compute, Sequence<1>{}, f_sum, SMPLComputeDataType{0});

            block_tile_reduce_sync(rowsum_p, f_sum);

            // l{j}, Oacc{j}
            sweep_tile_span(p_spans[I0], [&](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);

                const auto tmp = math::exp(m_old[i_idx] - m[i_idx]);

                l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];

                sweep_tile_span(p_spans[I1], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);

                    // FIXME: this uses a different equation from the FA v2 paper,
                    // but produces the correct result.
                    // Is the equation wrong?
                    o_acc(i_j_idx) *= tmp;
                });
            });

            // type cast Pcompute{j} into P{j}
            const auto p =
                tile_elementwise_in(type_convert<PDataType, SMPLComputeDataType>, p_compute);

            // Block GEMM1: Oacc{j} += P{j} * V{j}
            {
                // load V{j}
                const auto v = load_tile(v_dram_window);

                // wait for gemm0 pipeline to finish
                block_sync_lds();

                store_tile(v_lds_window, v);

                // wait for store_tile to finish
                block_sync_lds();

                // Oacc{j} += P{j} * V{j}
                gemm1(o_acc, p, v_lds_window);

                // wait for gemm1 to finish
                block_sync_lds();
            }

            // move tile windows
            move_tile_window(k_dram_window, {kN0PerBlock, 0});
            move_tile_window(v_dram_window, {0, kN0PerBlock});

            iN0 += kN0PerBlock;

        } while(iN0 < N0);

        // O
        constexpr auto o_spans = decltype(o_acc)::GetDistributedSpans();

        sweep_tile_span(o_spans[I0], [&](auto idx0) {
            constexpr auto i_idx = make_tuple(idx0);

            const auto tmp = 1 / l[i_idx];

            sweep_tile_span(o_spans[I1], [&](auto idx1) {
                constexpr auto i_j_idx = make_tuple(idx0, idx1);
                o_acc(i_j_idx) *= tmp;
            });
        });

        // type cast Oacc into O
        const auto o = tile_elementwise_in(type_convert<ODataType, OaccDataType>, o_acc);

        // O DRAM and O DRAM window
        auto o_dram = make_naive_tensor_view<AddressSpaceEnum::Global>(
            o_ptr, make_tuple(M0, N1), make_tuple(StrideO, 1), Number<32>{}, Number<1>{});

        auto o_dram_window = make_tile_window(o_dram,
                                              make_tuple(Number<kM0PerBlock>{}, Number<kN1PerBlock>{}),
                                              {iM0, iN1},
                                              o.GetTileDistribution());

        // store O
        store_tile(o_dram_window, o);
    }
};
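The J loop above is an online softmax in the FlashAttention style: each pass over a kN0PerBlock-wide strip of columns updates a running row maximum m, a running denominator l, and an unnormalized output accumulator, which is divided by l only once after the loop. One consistent way to write the recurrence the code implements (a sketch of the math, with j indexing column strips; the debug branch removed from gemm_softmax_gemm.hpp in this commit notes that rescaling by 1/exp(m_{j-1} - m_j) instead, as transcribed from the FA v2 paper, produced -nan here):

\[
\begin{aligned}
m_j &= \max\bigl(m_{j-1},\ \operatorname{rowmax}(S_j)\bigr), &
\tilde{P}_j &= \exp(S_j - m_j),\\
l_j &= e^{m_{j-1}-m_j}\, l_{j-1} + \operatorname{rowsum}(\tilde{P}_j), &
O_j &= e^{m_{j-1}-m_j}\, O_{j-1} + \tilde{P}_j V_j,
\end{aligned}
\qquad O = O_J / l_J .
\]

Unrolling the recurrence gives \(O_J = \sum_j \exp(S_j - m_J)\, V_j\) and \(l_J = \sum_j \operatorname{rowsum}(\exp(S_j - m_J))\), so the single division at the end reproduces the exact softmax-weighted product.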
example/91_tile_program/reference_batched_gemm.hpp — new file (0 → 100644, +36):

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/library/utility/host_tensor.hpp"

template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
void reference_batched_gemm(const Tensor<ADataType>& a_b_m_k,
                            const Tensor<BDataType>& b_b_n_k,
                            Tensor<CDataType>& c_b_m_n)
{
    const int N = b_b_n_k.mDesc.GetLengths()[1];
    const int K = b_b_n_k.mDesc.GetLengths()[2];

    auto f = [&](auto batch, auto m) {
        for(int n = 0; n < N; ++n)
        {
            AccDataType v_acc = 0;

            for(int k = 0; k < K; ++k)
            {
                ADataType v_a = a_b_m_k(batch, m, k);
                BDataType v_b = b_b_n_k(batch, n, k);

                v_acc += ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
            }

            c_b_m_n(batch, m, n) = ck::type_convert<CDataType>(v_acc);
        }
    };

    make_ParallelTensorFunctor(f,
                               c_b_m_n.mDesc.GetLengths()[0],
                               c_b_m_n.mDesc.GetLengths()[1])(std::thread::hardware_concurrency());
}
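Both operands are stored as [batch, rows, K], so per batch the reference computes

\[
C[b, m, n] = \sum_{k=0}^{K-1} A[b, m, k] \, B[b, n, k],
\]

that is, A times B transposed in conventional terms, which is why the host code above gives K the shape [Batch, N0, K0] and V the shape [Batch, N1, N0].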
example/91_tile_program/reference_batched_softmax.hpp — new file (0 → 100644, +46):

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/library/utility/host_tensor.hpp"

template <typename ADataType, typename AccDataType, typename BDataType>
void reference_batched_softmax(const Tensor<ADataType>& a_b_m_n, Tensor<BDataType>& b_b_m_n)
{
    const int N = a_b_m_n.mDesc.GetLengths()[2];

    auto f = [&](auto batch, auto m) {
        AccDataType v_max = ck::NumericLimits<ADataType>::Lowest();

        // max
        for(int n = 0; n < N; ++n)
        {
            const ADataType v_a = a_b_m_n(batch, m, n);

            v_max = v_max < v_a ? v_a : v_max;
        }

        AccDataType v_exp_sum = 0;

        // sum
        for(int n = 0; n < N; ++n)
        {
            const ADataType v_a = a_b_m_n(batch, m, n);

            v_exp_sum += ck::math::exp(v_a - v_max);
        }

        // elementwise
        for(int n = 0; n < N; ++n)
        {
            const ADataType v_a = a_b_m_n(batch, m, n);

            b_b_m_n(batch, m, n) = ck::math::exp(v_a - v_max) / v_exp_sum;
        }
    };

    make_ParallelTensorFunctor(f,
                               b_b_m_n.mDesc.GetLengths()[0],
                               b_b_m_n.mDesc.GetLengths()[1])(std::thread::hardware_concurrency());
}
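The three passes (row maximum, exponential sum, normalization) implement the numerically stable softmax along the last dimension:

\[
B[b, m, n] = \frac{\exp\bigl(A[b, m, n] - \max_{n'} A[b, m, n']\bigr)}
                  {\sum_{n'} \exp\bigl(A[b, m, n'] - \max_{n''} A[b, m, n'']\bigr)} .
\]

Subtracting the row maximum before exponentiating keeps every argument of exp nonpositive, which avoids overflow when the inputs are half or single precision.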
example/91_tile_program/reference_gemm.hpp (+16, -14):

@@ -6,28 +6,30 @@
 #include "ck/utility/common_header.hpp"
 #include "ck/library/utility/host_tensor.hpp"

-template <typename ADataType, typename BDataType, typename CDataType, typename AccDataType>
+template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
 void reference_gemm(const Tensor<ADataType>& a_m_k,
                     const Tensor<BDataType>& b_n_k,
                     Tensor<CDataType>& c_m_n)
 {
-    auto f_mk_kn_mn = [&](auto m, auto n) {
-        const int K = a_m_k.mDesc.GetLengths()[1];
-
-        AccDataType v_acc = 0;
-
-        for(int k = 0; k < K; ++k)
-        {
-            ADataType v_a = a_m_k(m, k);
-            BDataType v_b = b_n_k(n, k);
-
-            v_acc += ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
-        }
-
-        c_m_n(m, n) = ck::type_convert<CDataType>(v_acc);
-    };
-
-    make_ParallelTensorFunctor(f_mk_kn_mn,
-                               c_m_n.mDesc.GetLengths()[0],
-                               c_m_n.mDesc.GetLengths()[1])(std::thread::hardware_concurrency());
+    const int N = b_n_k.mDesc.GetLengths()[0];
+    const int K = b_n_k.mDesc.GetLengths()[1];
+
+    auto f = [&](auto m) {
+        for(int n = 0; n < N; ++n)
+        {
+            AccDataType v_acc = 0;
+
+            for(int k = 0; k < K; ++k)
+            {
+                ADataType v_a = a_m_k(m, k);
+                BDataType v_b = b_n_k(n, k);
+
+                v_acc += ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
+            }
+
+            c_m_n(m, n) = ck::type_convert<CDataType>(v_acc);
+        }
+    };
+
+    make_ParallelTensorFunctor(f,
+                               c_m_n.mDesc.GetLengths()[0])(std::thread::hardware_concurrency());
 }
example/91_tile_program/softmax.hpp
View file @
2837e6b3
...
@@ -143,23 +143,20 @@ struct Softmax
         sweep_tile_span(a_spans[I0], [&](auto idx0) {
             constexpr auto m_idx = make_tuple(idx0);

-            const auto v_max = max_block_tensor.GetElementFromTileDistributedIndices(m_idx);
+            const auto v_max = max_block_tensor[m_idx];

-            AccDataType v_exp_sum =
-                exp_sum_block_tensor.GetElementFromTileDistributedIndices(m_idx);
+            AccDataType v_exp_sum = exp_sum_block_tensor[m_idx];

             sweep_tile_span(a_spans[I1], [&](auto idx1) {
                 constexpr auto m_n_idx = make_tuple(idx0, idx1);

-                const auto v_a = a_block_tensor.GetElementFromTileDistributedIndices(m_n_idx);
-
-                (void)v_max;
+                const auto v_a = a_block_tensor[m_n_idx];

                 // exp and sum
                 v_exp_sum += math::exp(v_a - v_max);
             });

-            exp_sum_block_tensor.SetElementFromTileDistributedIndices(m_idx, v_exp_sum);
+            exp_sum_block_tensor(m_idx) = v_exp_sum;
         });

         move_tile_window(a_block_window, {0, kNPerBlock});
...
@@ -196,21 +193,20 @@ struct Softmax
         sweep_tile_span(a_spans[I0], [&](auto idx0) {
             constexpr auto m_idx = make_tuple(idx0);

-            const auto v_max = max_block_tensor.GetElementFromTileDistributedIndices(m_idx);
+            const auto v_max = max_block_tensor[m_idx];

-            const auto v_exp_sum =
-                exp_sum_block_tensor.GetElementFromTileDistributedIndices(m_idx);
+            const auto v_exp_sum = exp_sum_block_tensor[m_idx];

             sweep_tile_span(a_spans[I1], [&](auto idx1) {
                 constexpr auto m_n_idx = make_tuple(idx0, idx1);

-                const auto v_a = a_block_tensor.GetElementFromTileDistributedIndices(m_n_idx);
+                const auto v_a = a_block_tensor[m_n_idx];

                 // exp
                 const BDataType v_b =
                     type_convert<BDataType>(math::exp(v_a - v_max) / v_exp_sum);

-                b_block_tensor.SetElementFromTileDistributedIndices(m_n_idx, v_b);
+                b_block_tensor(m_n_idx) = v_b;
             });
         });
...
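Both hunks apply the same mechanical migration: tensor.GetElementFromTileDistributedIndices(idx) becomes tensor[idx] for reads, and tensor.SetElementFromTileDistributedIndices(idx, v) becomes tensor(idx) = v for writes, using the operators added to StaticDistributedTensor further down in this commit. The (void)v_max; suppressor, apparently left over from an earlier revision (v_max was already read by the exp-and-sum line), is dropped along with it.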
include/ck/host_utility/kernel_launch.hpp
View file @
2837e6b3
...
@@ -160,11 +160,11 @@ float launch_kernel(const StreamConfig& stream_config,
                     KernelImpl kernel_impl,
                     dim3 grid_dim,
                     dim3 block_dim,
-                    std::size_t lds_byte,
+                    std::size_t dynamic_smem_byte,
                     Args... args)
 {
     const auto kernel = kernel_wrapper<MaxThreadPerBlock, MinBlockPerCu, KernelImpl, Args...>;

     return launch_and_time_kernel(
-        stream_config, kernel, grid_dim, block_dim, lds_byte, kernel_impl, args...);
+        stream_config, kernel, grid_dim, block_dim, dynamic_smem_byte, kernel_impl, args...);
 }
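The rename makes the parameter's meaning explicit: it is the dynamic shared-memory (LDS) allocation requested at launch time, not the kernel's total LDS footprint. A hedged sketch of where such a value ends up — this is not this header's actual code, just the underlying HIP launch convention, where hipLaunchKernelGGL's fourth argument is the dynamic LDS size in bytes:

#include <hip/hip_runtime.h>
#include <cstddef>
#include <cstdint>

// Hedged illustration: forward a dynamic LDS size to HIP's launch macro.
template <typename Kernel, typename... Args>
void launch_with_dynamic_smem(Kernel kernel,
                              dim3 grid_dim,
                              dim3 block_dim,
                              std::size_t dynamic_smem_byte,
                              hipStream_t stream,
                              Args... args)
{
    hipLaunchKernelGGL(kernel,
                       grid_dim,
                       block_dim,
                       static_cast<std::uint32_t>(dynamic_smem_byte), // dynamic LDS bytes
                       stream,
                       args...);
}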
include/ck/tile_program/block_tile/block_reduce.hpp
View file @
2837e6b3
...
@@ -161,18 +161,18 @@ __device__ void block_tile_reduce(AccDistributedTensor_& acc_tensor,
     sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) {
         constexpr auto acc_dstr_idx = make_tuple(dstr_idx_i0);

-        auto acc = acc_tensor.GetElementFromTileDistributedIndices(acc_dstr_idx);
+        auto acc = acc_tensor[acc_dstr_idx];

         // FIXME
         sweep_tile_span(spans[I1], [&](auto dstr_idx_i1) {
             constexpr auto in_dstr_idx = make_tuple(dstr_idx_i0, dstr_idx_i1);

-            const auto in = in_tensor.GetElementFromTileDistributedIndices(in_dstr_idx);
+            const auto in = in_tensor[in_dstr_idx];

             acc = reduce_func(acc, in);
         });

-        acc_tensor.SetElementFromTileDistributedIndices(acc_dstr_idx, acc);
+        acc_tensor(acc_dstr_idx) = acc;
     });
 #endif
 }
...
include/ck/tile_program/tile/static_distributed_tensor.hpp
View file @
2837e6b3
...
@@ -105,6 +105,31 @@ struct StaticDistributedTensor
         });
     }

+    template <typename TileDistributedIndices>
+    __host__ __device__ constexpr const DataType& operator[](TileDistributedIndices) const
+    {
+        static_assert(is_static_v<TileDistributedIndices>,
+                      "wrong! Tile Distributed Indices should be static");
+
+        constexpr auto y_idx = GetTileDistribution().GetYIndicesFromDistributedIndices(
+            TileDistributedIndices{});
+
+        return thread_buf_[Number<ThreadTensorDesc{}.CalculateOffset(y_idx)>{}];
+    }
+
+    template <typename TileDistributedIndices>
+    __host__ __device__ constexpr DataType& operator()(TileDistributedIndices)
+    {
+        static_assert(is_static_v<TileDistributedIndices>,
+                      "wrong! Tile Distributed Indices should be static");
+
+        constexpr auto y_idx = GetTileDistribution().GetYIndicesFromDistributedIndices(
+            TileDistributedIndices{});
+
+        return thread_buf_(Number<ThreadTensorDesc{}.CalculateOffset(y_idx)>{});
+    }
+
+#if 0
     template <index_t... Ys>
     __host__ __device__ auto GetElementFromYsIndex(Sequence<Ys...> idx_ys) const
     {
...
@@ -116,7 +141,6 @@ struct StaticDistributedTensor
     {
         thread_buf_(Number<ThreadTensorDesc{}.CalculateOffset(idx_ys)>{}) = v;
     }
-
     template <typename TileDistributedIndices>
     __host__ __device__ auto GetElementFromTileDistributedIndices(TileDistributedIndices) const
     {
...
@@ -139,6 +163,7 @@ struct StaticDistributedTensor
         return SetElementFromYsIndex(y_idx, v);
     }
+#endif

     //
     StaticBuffer<AddressSpaceEnum::Vgpr, DataType, kThreadElementSpaceSize, true> thread_buf_;
...
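Both new operators resolve the compile-time distributed indices to Y-indices, fold the thread-buffer offset at compile time, and index the per-thread VGPR buffer, so they are drop-in replacements for the getters/setters this commit disables under #if 0. A device-side fragment showing the mapping, with names borrowed from the block_reduce.hpp hunk above:

// old (now under #if 0):
auto acc = acc_tensor.GetElementFromTileDistributedIndices(acc_dstr_idx);
acc_tensor.SetElementFromTileDistributedIndices(acc_dstr_idx, acc);

// new: operator[] yields a const reference, operator() a mutable one.
auto acc = acc_tensor[acc_dstr_idx];
acc_tensor(acc_dstr_idx) = acc;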