Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
09d4c3a4
"tests/vscode:/vscode.git/clone" did not exist on "72182747a17231075b26768d37694781bd992daf"
Commit
09d4c3a4
authored
Oct 01, 2024
by
illsilin
Browse files
merge from public repo
parents
171ed358
8e4c3fb1
Changes
202
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
649 additions
and
202 deletions
+649
-202
include/ck_tile/host/host_tensor.hpp
include/ck_tile/host/host_tensor.hpp
+14
-1
include/ck_tile/host/reference/reference_gemm.hpp
include/ck_tile/host/reference/reference_gemm.hpp
+131
-4
include/ck_tile/host/reference/reference_im2col.hpp
include/ck_tile/host/reference/reference_im2col.hpp
+117
-45
include/ck_tile/ops/fmha/block/block_masking.hpp
include/ck_tile/ops/fmha/block/block_masking.hpp
+2
-2
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp
..._tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp
+22
-29
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp
...fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp
+8
-9
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
+21
-23
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp
...ile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp
+2
-2
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp
...eline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp
+6
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
.../fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
+50
-45
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp
...fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp
+5
-5
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
...mha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
+4
-5
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
...k_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
+2
-2
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
.../ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
+1
-1
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
...a/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
+27
-24
include/ck_tile/ops/gemm.hpp
include/ck_tile/ops/gemm.hpp
+2
-0
include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp
.../ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp
+6
-5
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp
...m/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp
+4
-0
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+191
-0
include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp
include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp
+34
-0
No files found.
include/ck_tile/host/host_tensor.hpp
View file @
09d4c3a4
...
@@ -176,7 +176,20 @@ struct HostTensorDescriptor
...
@@ -176,7 +176,20 @@ struct HostTensorDescriptor
return
std
::
inner_product
(
iss
.
begin
(),
iss
.
end
(),
mStrides
.
begin
(),
std
::
size_t
{
0
});
return
std
::
inner_product
(
iss
.
begin
(),
iss
.
end
(),
mStrides
.
begin
(),
std
::
size_t
{
0
});
}
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
HostTensorDescriptor
&
desc
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
HostTensorDescriptor
&
desc
)
{
os
<<
"dim "
<<
desc
.
get_num_of_dimension
()
<<
", "
;
os
<<
"lengths {"
;
LogRange
(
os
,
desc
.
get_lengths
(),
", "
);
os
<<
"}, "
;
os
<<
"strides {"
;
LogRange
(
os
,
desc
.
get_strides
(),
", "
);
os
<<
"}"
;
return
os
;
}
private:
private:
std
::
vector
<
std
::
size_t
>
mLens
;
std
::
vector
<
std
::
size_t
>
mLens
;
...
...
include/ck_tile/host/reference/reference_gemm.hpp
View file @
09d4c3a4
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include <thread>
#include <thread>
namespace
ck_tile
{
namespace
ck_tile
{
...
@@ -13,6 +14,9 @@ template <typename ADataType,
...
@@ -13,6 +14,9 @@ template <typename ADataType,
typename
BDataType
,
typename
BDataType
,
typename
AccDataType
,
typename
AccDataType
,
typename
CDataType
,
typename
CDataType
,
typename
LayoutA
,
typename
LayoutB
,
typename
LayoutC
,
typename
AElementOp
=
ck_tile
::
identity
,
typename
AElementOp
=
ck_tile
::
identity
,
typename
BElementOp
=
ck_tile
::
identity
,
typename
BElementOp
=
ck_tile
::
identity
,
typename
ACCElementOp
=
ck_tile
::
identity
>
typename
ACCElementOp
=
ck_tile
::
identity
>
...
@@ -24,7 +28,12 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
...
@@ -24,7 +28,12 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
const
ACCElementOp
&
acc_element_op
=
{})
const
ACCElementOp
&
acc_element_op
=
{})
{
{
const
int
N
=
b_n_k
.
mDesc
.
get_lengths
()[
0
];
const
int
N
=
b_n_k
.
mDesc
.
get_lengths
()[
0
];
const
int
K
=
b_n_k
.
mDesc
.
get_lengths
()[
1
];
const
int
K
=
(
std
::
is_same_v
<
LayoutA
,
tensor_layout
::
gemm
::
RowMajor
>
)
?
a_m_k
.
mDesc
.
get_lengths
()[
1
]
:
a_m_k
.
mDesc
.
get_lengths
()[
0
];
const
int
M
=
(
std
::
is_same_v
<
LayoutA
,
tensor_layout
::
gemm
::
RowMajor
>
)
?
a_m_k
.
mDesc
.
get_lengths
()[
0
]
:
a_m_k
.
mDesc
.
get_lengths
()[
1
];
auto
f
=
[
&
](
auto
m
)
{
auto
f
=
[
&
](
auto
m
)
{
for
(
int
n
=
0
;
n
<
N
;
++
n
)
for
(
int
n
=
0
;
n
<
N
;
++
n
)
...
@@ -33,7 +42,9 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
...
@@ -33,7 +42,9 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
for
(
int
k
=
0
;
k
<
K
;
++
k
)
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
{
ADataType
v_a
=
a_element_op
(
a_m_k
(
m
,
k
));
ADataType
v_a
=
(
std
::
is_same_v
<
LayoutA
,
tensor_layout
::
gemm
::
RowMajor
>
)
?
a_element_op
(
a_m_k
(
m
,
k
))
:
a_element_op
(
a_m_k
(
k
,
m
));
BDataType
v_b
=
b_element_op
(
b_n_k
(
n
,
k
));
BDataType
v_b
=
b_element_op
(
b_n_k
(
n
,
k
));
v_acc
+=
ck_tile
::
type_convert
<
AccDataType
>
(
v_a
)
*
v_acc
+=
ck_tile
::
type_convert
<
AccDataType
>
(
v_a
)
*
...
@@ -44,7 +55,123 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
...
@@ -44,7 +55,123 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
}
}
};
};
make_ParallelTensorFunctor
(
f
,
make_ParallelTensorFunctor
(
f
,
M
)(
std
::
thread
::
hardware_concurrency
());
c_m_n
.
mDesc
.
get_lengths
()[
0
])(
std
::
thread
::
hardware_concurrency
());
}
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
>
__global__
void
naive_gemm_kernel
(
ADataType
*
A
,
BDataType
*
B
,
CDataType
*
C
,
ck_tile
::
index_t
M
,
ck_tile
::
index_t
N
,
ck_tile
::
index_t
K
,
ck_tile
::
index_t
strideA
,
ck_tile
::
index_t
strideB
,
ck_tile
::
index_t
strideC
)
{
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
row
=
idx
/
N
;
// Compute row index
int
col
=
idx
%
N
;
// Compute column index
if
(
row
<
M
&&
col
<
N
)
{
AccDataType
acc
=
0.0
;
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
acc
+=
static_cast
<
AccDataType
>
(
A
[
row
*
strideA
+
k
])
*
static_cast
<
AccDataType
>
(
B
[
col
*
strideB
+
k
]);
}
C
[
row
*
strideC
+
col
]
=
acc
;
// Store as AccDataType
}
}
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
>
void
reference_gemm_gpu
(
DeviceMem
&
a_device
,
DeviceMem
&
b_device
,
DeviceMem
&
c_device
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
stride_a
,
index_t
stride_b
,
index_t
stride_c
)
{
ADataType
*
d_A
;
BDataType
*
d_B
;
CDataType
*
d_C
;
hipError_t
errA
=
hipMalloc
(
&
d_A
,
M
*
K
*
sizeof
(
ADataType
));
hipError_t
errB
=
hipMalloc
(
&
d_B
,
N
*
K
*
sizeof
(
BDataType
));
hipError_t
errC
=
hipMalloc
(
&
d_C
,
M
*
N
*
sizeof
(
CDataType
));
if
(
errA
!=
hipSuccess
)
{
std
::
cerr
<<
"Error allocating device memory for A: "
<<
hipGetErrorString
(
errA
)
<<
std
::
endl
;
return
;
// Early exit on error
}
if
(
errB
!=
hipSuccess
)
{
std
::
cerr
<<
"Error allocating device memory for B: "
<<
hipGetErrorString
(
errB
)
<<
std
::
endl
;
return
;
// Early exit on error
}
if
(
errC
!=
hipSuccess
)
{
std
::
cerr
<<
"Error allocating device memory for C: "
<<
hipGetErrorString
(
errC
)
<<
std
::
endl
;
return
;
// Early exit on error
}
errA
=
hipMemcpy
(
d_A
,
a_device
.
GetDeviceBuffer
(),
M
*
K
*
sizeof
(
ADataType
),
hipMemcpyHostToDevice
);
if
(
errA
!=
hipSuccess
)
{
std
::
cerr
<<
"Error copying A to device: "
<<
hipGetErrorString
(
errA
)
<<
std
::
endl
;
}
errB
=
hipMemcpy
(
d_B
,
b_device
.
GetDeviceBuffer
(),
N
*
K
*
sizeof
(
BDataType
),
hipMemcpyHostToDevice
);
if
(
errB
!=
hipSuccess
)
{
std
::
cerr
<<
"Error copying B to device: "
<<
hipGetErrorString
(
errB
)
<<
std
::
endl
;
}
int
totalElements
=
M
*
N
;
int
numThreadsPerBlock
=
256
;
// Common choice for threads per block
int
numBlocks
=
(
totalElements
+
numThreadsPerBlock
-
1
)
/
numThreadsPerBlock
;
naive_gemm_kernel
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
<<<
numBlocks
,
numThreadsPerBlock
>>>
(
d_A
,
d_B
,
d_C
,
M
,
N
,
K
,
stride_a
,
stride_b
,
stride_c
);
errC
=
hipMemcpy
(
c_device
.
GetDeviceBuffer
(),
d_C
,
M
*
N
*
sizeof
(
CDataType
),
hipMemcpyDeviceToHost
);
if
(
errC
!=
hipSuccess
)
{
std
::
cerr
<<
"Error copying C to device: "
<<
hipGetErrorString
(
errC
)
<<
std
::
endl
;
}
errA
=
hipFree
(
d_A
);
if
(
errA
!=
hipSuccess
)
{
std
::
cerr
<<
"Error free the A memory: "
<<
hipGetErrorString
(
errA
)
<<
std
::
endl
;
}
errB
=
hipFree
(
d_B
);
if
(
errB
!=
hipSuccess
)
{
std
::
cerr
<<
"Error free the B memory: "
<<
hipGetErrorString
(
errB
)
<<
std
::
endl
;
}
errC
=
hipFree
(
d_C
);
if
(
errC
!=
hipSuccess
)
{
std
::
cerr
<<
"Error free the C memory: "
<<
hipGetErrorString
(
errC
)
<<
std
::
endl
;
}
return
;
}
}
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/host/reference/reference_im2col.hpp
View file @
09d4c3a4
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -9,53 +9,125 @@
...
@@ -9,53 +9,125 @@
namespace
ck_tile
{
namespace
ck_tile
{
template
<
typename
T
>
template
<
typename
InDataType
,
typename
OutDataType
,
index_t
NDimSpatial
>
CK_TILE_HOST
void
reference_im2col
(
HostTensor
<
T
>&
in_mtx_host_ref
,
CK_TILE_HOST
void
reference_im2col
(
const
HostTensor
<
InDataType
>&
in_host
,
const
HostTensor
<
T
>&
in_host
,
HostTensor
<
OutDataType
>&
out_host
,
int
/*N*/
,
const
ck_tile
::
conv
::
ConvParam
&
conv_params
)
int
/*K*/
,
int
C
,
int
/*Y*/
,
int
X
,
int
Hi
,
int
Wi
,
int
Ho
,
int
Wo
,
int
ConvStrideH
,
int
ConvStrideW
,
int
ConvDilationH
,
int
ConvDilationW
,
int
InLeftPadH
,
int
InLeftPadW
,
int
/*InRightPadH*/
,
int
/*InRightPadW*/
)
{
{
int
GemmM
=
in_mtx_host_ref
.
get_lengths
()[
0
];
const
long_index_t
G
=
in_host
.
get_lengths
()[
0
];
int
GemmK
=
in_mtx_host_ref
.
get_lengths
()[
1
];
const
long_index_t
N
=
in_host
.
get_lengths
()[
1
];
const
long_index_t
C
=
in_host
.
get_lengths
()[
2
];
for
(
int
gemm_m
=
0
;
gemm_m
<
GemmM
;
++
gemm_m
)
if
constexpr
(
NDimSpatial
==
1
)
{
{
int
mtmp
=
gemm_m
;
const
long_index_t
Wo
=
conv_params
.
output_spatial_lengths_
[
0
];
int
n
=
mtmp
/
(
Ho
*
Wo
);
auto
func
=
[
&
](
auto
g
,
auto
n
,
auto
wo
)
{
mtmp
-=
n
*
Ho
*
Wo
;
long_index_t
row
=
n
*
Wo
+
wo
;
int
ho
=
mtmp
/
Wo
;
long_index_t
column
=
0
;
int
wo
=
mtmp
-
ho
*
Wo
;
for
(
long_index_t
x
=
0
;
x
<
conv_params
.
filter_spatial_lengths_
[
0
];
++
x
)
for
(
int
gemm_k
=
0
;
gemm_k
<
GemmK
;
++
gemm_k
)
{
{
auto
wi
=
static_cast
<
long_index_t
>
(
wo
*
conv_params
.
conv_filter_strides_
[
0
])
+
int
ktmp
=
gemm_k
;
static_cast
<
long_index_t
>
(
x
*
conv_params
.
conv_filter_dilations_
[
0
])
-
int
y
=
ktmp
/
(
X
*
C
);
static_cast
<
long_index_t
>
(
conv_params
.
input_left_pads_
[
0
]);
ktmp
-=
y
*
X
*
C
;
int
x
=
ktmp
/
C
;
for
(
long_index_t
c
=
0
;
c
<
C
;
++
c
)
int
c
=
ktmp
-
x
*
C
;
{
if
(
wi
>=
0
&&
type_convert
<
std
::
size_t
>
(
wi
)
<
in_host
.
get_lengths
()[
3
])
int
hi
=
y
*
ConvDilationH
+
ho
*
ConvStrideH
-
InLeftPadH
;
{
int
wi
=
x
*
ConvDilationW
+
wo
*
ConvStrideW
-
InLeftPadW
;
InDataType
v_in
=
in_host
(
g
,
n
,
c
,
wi
);
out_host
(
g
,
row
,
column
)
=
type_convert
<
OutDataType
>
(
v_in
);
bool
inbound
=
(
hi
>=
0
&&
hi
<
Hi
&&
wi
>=
0
&&
wi
<
Wi
);
}
column
++
;
in_mtx_host_ref
(
gemm_m
,
gemm_k
)
=
inbound
?
in_host
(
n
,
hi
,
wi
,
c
)
:
0
;
}
}
}
};
make_ParallelTensorFunctor
(
func
,
G
,
N
,
Wo
)(
std
::
thread
::
hardware_concurrency
());
}
else
if
constexpr
(
NDimSpatial
==
2
)
{
const
long_index_t
Ho
=
conv_params
.
output_spatial_lengths_
[
0
];
const
long_index_t
Wo
=
conv_params
.
output_spatial_lengths_
[
1
];
auto
func
=
[
&
](
auto
g
,
auto
n
,
auto
ho
,
auto
wo
)
{
long_index_t
row
=
n
*
Ho
*
Wo
+
ho
*
Wo
+
wo
;
long_index_t
column
=
0
;
for
(
long_index_t
y
=
0
;
y
<
conv_params
.
filter_spatial_lengths_
[
0
];
++
y
)
{
auto
hi
=
static_cast
<
long_index_t
>
(
ho
*
conv_params
.
conv_filter_strides_
[
0
])
+
static_cast
<
long_index_t
>
(
y
*
conv_params
.
conv_filter_dilations_
[
0
])
-
static_cast
<
long_index_t
>
(
conv_params
.
input_left_pads_
[
0
]);
for
(
long_index_t
x
=
0
;
x
<
conv_params
.
filter_spatial_lengths_
[
1
];
++
x
)
{
auto
wi
=
static_cast
<
long_index_t
>
(
wo
*
conv_params
.
conv_filter_strides_
[
1
])
+
static_cast
<
long_index_t
>
(
x
*
conv_params
.
conv_filter_dilations_
[
1
])
-
static_cast
<
long_index_t
>
(
conv_params
.
input_left_pads_
[
1
]);
for
(
long_index_t
c
=
0
;
c
<
C
;
++
c
)
{
if
(
hi
>=
0
&&
type_convert
<
std
::
size_t
>
(
hi
)
<
in_host
.
get_lengths
()[
3
]
&&
wi
>=
0
&&
type_convert
<
std
::
size_t
>
(
wi
)
<
in_host
.
get_lengths
()[
4
])
{
InDataType
v_in
=
in_host
(
g
,
n
,
c
,
hi
,
wi
);
out_host
(
g
,
row
,
column
)
=
type_convert
<
OutDataType
>
(
v_in
);
}
column
++
;
}
}
}
};
make_ParallelTensorFunctor
(
func
,
G
,
N
,
Ho
,
Wo
)(
std
::
thread
::
hardware_concurrency
());
}
else
if
constexpr
(
NDimSpatial
==
3
)
{
const
long_index_t
Do
=
conv_params
.
output_spatial_lengths_
[
0
];
const
long_index_t
Ho
=
conv_params
.
output_spatial_lengths_
[
1
];
const
long_index_t
Wo
=
conv_params
.
output_spatial_lengths_
[
2
];
auto
func
=
[
&
](
auto
g
,
auto
n
,
auto
d_o
,
auto
ho
,
auto
wo
)
{
long_index_t
row
=
n
*
Do
*
Ho
*
Wo
+
d_o
*
Ho
*
Wo
+
ho
*
Wo
+
wo
;
long_index_t
column
=
0
;
for
(
long_index_t
z
=
0
;
z
<
conv_params
.
filter_spatial_lengths_
[
0
];
++
z
)
{
auto
di
=
static_cast
<
long_index_t
>
(
d_o
*
conv_params
.
conv_filter_strides_
[
0
])
+
static_cast
<
long_index_t
>
(
z
*
conv_params
.
conv_filter_dilations_
[
0
])
-
static_cast
<
long_index_t
>
(
conv_params
.
input_left_pads_
[
0
]);
for
(
long_index_t
y
=
0
;
y
<
conv_params
.
filter_spatial_lengths_
[
1
];
++
y
)
{
auto
hi
=
static_cast
<
long_index_t
>
(
ho
*
conv_params
.
conv_filter_strides_
[
1
])
+
static_cast
<
long_index_t
>
(
y
*
conv_params
.
conv_filter_dilations_
[
1
])
-
static_cast
<
long_index_t
>
(
conv_params
.
input_left_pads_
[
1
]);
for
(
long_index_t
x
=
0
;
x
<
conv_params
.
filter_spatial_lengths_
[
2
];
++
x
)
{
auto
wi
=
static_cast
<
long_index_t
>
(
wo
*
conv_params
.
conv_filter_strides_
[
2
])
+
static_cast
<
long_index_t
>
(
x
*
conv_params
.
conv_filter_dilations_
[
2
])
-
static_cast
<
long_index_t
>
(
conv_params
.
input_left_pads_
[
2
]);
for
(
long_index_t
c
=
0
;
c
<
C
;
++
c
)
{
if
(
di
>=
0
&&
type_convert
<
std
::
size_t
>
(
di
)
<
in_host
.
get_lengths
()[
3
]
&&
hi
>=
0
&&
type_convert
<
std
::
size_t
>
(
hi
)
<
in_host
.
get_lengths
()[
4
]
&&
wi
>=
0
&&
type_convert
<
std
::
size_t
>
(
wi
)
<
in_host
.
get_lengths
()[
5
])
{
InDataType
v_in
=
in_host
(
g
,
n
,
c
,
di
,
hi
,
wi
);
out_host
(
g
,
row
,
column
)
=
type_convert
<
OutDataType
>
(
v_in
);
}
column
++
;
}
}
}
}
};
make_ParallelTensorFunctor
(
func
,
G
,
N
,
Do
,
Ho
,
Wo
)(
std
::
thread
::
hardware_concurrency
());
}
}
}
}
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/fmha/block/block_masking.hpp
View file @
09d4c3a4
...
@@ -308,9 +308,9 @@ struct SimplifiedGenericAttentionMask
...
@@ -308,9 +308,9 @@ struct SimplifiedGenericAttentionMask
{
{
auto
[
origin_start
,
origin_end
]
=
GetTileRangeAlongX
(
i_y
,
height
,
width
);
auto
[
origin_start
,
origin_end
]
=
GetTileRangeAlongX
(
i_y
,
height
,
width
);
const
index_t
x_per_split
=
ck_tile
::
max
(
1
,
x_total
/
num_splits
);
const
index_t
x_per_split
=
ck_tile
::
max
(
1
,
integer_divide_ceil
(
x_total
,
num_splits
)
)
;
const
index_t
split_start
=
x_per_split
*
i_split
;
const
index_t
split_start
=
x_per_split
*
i_split
;
const
index_t
split_end
=
(
i_split
==
num_splits
-
1
?
x_total
:
split_start
+
x_per_split
)
;
const
index_t
split_end
=
split_start
+
x_per_split
;
return
ck_tile
::
make_tuple
(
ck_tile
::
max
(
origin_start
,
split_start
),
return
ck_tile
::
make_tuple
(
ck_tile
::
max
(
origin_start
,
split_start
),
ck_tile
::
min
(
origin_end
,
split_end
));
ck_tile
::
min
(
origin_end
,
split_end
));
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp
View file @
09d4c3a4
...
@@ -78,8 +78,6 @@ struct FmhaFwdSplitKVCombineKernel
...
@@ -78,8 +78,6 @@ struct FmhaFwdSplitKVCombineKernel
void
*
o_ptr
;
void
*
o_ptr
;
ck_tile
::
index_t
batch
;
ck_tile
::
index_t
batch
;
ck_tile
::
index_t
max_seqlen_q
;
ck_tile
::
index_t
seqlen_q
;
ck_tile
::
index_t
seqlen_q
;
ck_tile
::
index_t
hdim_v
;
ck_tile
::
index_t
hdim_v
;
ck_tile
::
index_t
num_splits
;
ck_tile
::
index_t
num_splits
;
...
@@ -91,8 +89,6 @@ struct FmhaFwdSplitKVCombineKernel
...
@@ -91,8 +89,6 @@ struct FmhaFwdSplitKVCombineKernel
ck_tile
::
index_t
nhead_stride_o_acc
;
ck_tile
::
index_t
nhead_stride_o_acc
;
ck_tile
::
index_t
nhead_stride_o
;
ck_tile
::
index_t
nhead_stride_o
;
ck_tile
::
index_t
batch_stride_o_acc
;
ck_tile
::
index_t
split_stride_lse_acc
;
ck_tile
::
index_t
split_stride_lse_acc
;
ck_tile
::
index_t
split_stride_o_acc
;
ck_tile
::
index_t
split_stride_o_acc
;
};
};
...
@@ -114,8 +110,9 @@ struct FmhaFwdSplitKVCombineKernel
...
@@ -114,8 +110,9 @@ struct FmhaFwdSplitKVCombineKernel
std
::
conditional_t
<
kStoreLSE
,
CommonLSEKargs
,
EmptyKargs
<
0
>>
,
std
::
conditional_t
<
kStoreLSE
,
CommonLSEKargs
,
EmptyKargs
<
0
>>
,
std
::
conditional_t
<
kDoFp8StaticQuant
,
Fp8StaticQuantKargs
,
EmptyKargs
<
1
>>
std
::
conditional_t
<
kDoFp8StaticQuant
,
Fp8StaticQuantKargs
,
EmptyKargs
<
1
>>
{
{
ck_tile
::
index_t
batch_stride_o
;
ck_tile
::
index_t
batch_stride_lse_acc
;
ck_tile
::
index_t
batch_stride_lse_acc
;
ck_tile
::
index_t
batch_stride_o_acc
;
ck_tile
::
index_t
batch_stride_o
;
};
};
struct
GroupModeKargs
struct
GroupModeKargs
...
@@ -135,7 +132,6 @@ struct FmhaFwdSplitKVCombineKernel
...
@@ -135,7 +132,6 @@ struct FmhaFwdSplitKVCombineKernel
void
*
lse_ptr
,
void
*
lse_ptr
,
void
*
o_ptr
,
void
*
o_ptr
,
ck_tile
::
index_t
batch
,
ck_tile
::
index_t
batch
,
ck_tile
::
index_t
max_seqlen_q
,
ck_tile
::
index_t
seqlen_q
,
ck_tile
::
index_t
seqlen_q
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
num_splits
,
ck_tile
::
index_t
num_splits
,
...
@@ -157,7 +153,6 @@ struct FmhaFwdSplitKVCombineKernel
...
@@ -157,7 +153,6 @@ struct FmhaFwdSplitKVCombineKernel
o_acc_ptr
,
o_acc_ptr
,
o_ptr
,
o_ptr
,
batch
,
batch
,
max_seqlen_q
,
seqlen_q
,
seqlen_q
,
hdim_v
,
hdim_v
,
num_splits
,
num_splits
,
...
@@ -166,13 +161,13 @@ struct FmhaFwdSplitKVCombineKernel
...
@@ -166,13 +161,13 @@ struct FmhaFwdSplitKVCombineKernel
nhead_stride_lse_acc
,
nhead_stride_lse_acc
,
nhead_stride_o_acc
,
nhead_stride_o_acc
,
nhead_stride_o
,
nhead_stride_o
,
batch_stride_o_acc
,
split_stride_lse_acc
,
split_stride_lse_acc
,
split_stride_o_acc
},
// args for common karg
split_stride_o_acc
},
// args for common karg
{},
// placeholder for lse
{},
// placeholder for lse
{},
// placeholder for fp8_static_quant args
{},
// placeholder for fp8_static_quant args
batch_stride_o
,
batch_stride_lse_acc
,
batch_stride_lse_acc
};
batch_stride_o_acc
,
batch_stride_o
};
if
constexpr
(
kStoreLSE
)
if
constexpr
(
kStoreLSE
)
{
{
...
@@ -195,7 +190,6 @@ struct FmhaFwdSplitKVCombineKernel
...
@@ -195,7 +190,6 @@ struct FmhaFwdSplitKVCombineKernel
void
*
lse_ptr
,
void
*
lse_ptr
,
void
*
o_ptr
,
void
*
o_ptr
,
ck_tile
::
index_t
batch
,
ck_tile
::
index_t
batch
,
ck_tile
::
index_t
max_seqlen_q
,
const
void
*
seqstart_q_ptr
,
const
void
*
seqstart_q_ptr
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
num_splits
,
ck_tile
::
index_t
num_splits
,
...
@@ -206,7 +200,6 @@ struct FmhaFwdSplitKVCombineKernel
...
@@ -206,7 +200,6 @@ struct FmhaFwdSplitKVCombineKernel
ck_tile
::
index_t
nhead_stride_o_acc
,
ck_tile
::
index_t
nhead_stride_o_acc
,
ck_tile
::
index_t
nhead_stride_lse
,
ck_tile
::
index_t
nhead_stride_lse
,
ck_tile
::
index_t
nhead_stride_o
,
ck_tile
::
index_t
nhead_stride_o
,
ck_tile
::
index_t
batch_stride_o_acc
,
ck_tile
::
index_t
split_stride_lse_acc
,
ck_tile
::
index_t
split_stride_lse_acc
,
ck_tile
::
index_t
split_stride_o_acc
)
ck_tile
::
index_t
split_stride_o_acc
)
{
{
...
@@ -214,7 +207,6 @@ struct FmhaFwdSplitKVCombineKernel
...
@@ -214,7 +207,6 @@ struct FmhaFwdSplitKVCombineKernel
o_acc_ptr
,
o_acc_ptr
,
o_ptr
,
o_ptr
,
batch
,
batch
,
max_seqlen_q
,
-
1
,
// seqlen will be updated by another pointer
-
1
,
// seqlen will be updated by another pointer
hdim_v
,
hdim_v
,
num_splits
,
num_splits
,
...
@@ -223,7 +215,6 @@ struct FmhaFwdSplitKVCombineKernel
...
@@ -223,7 +215,6 @@ struct FmhaFwdSplitKVCombineKernel
nhead_stride_lse_acc
,
nhead_stride_lse_acc
,
nhead_stride_o_acc
,
nhead_stride_o_acc
,
nhead_stride_o
,
nhead_stride_o
,
batch_stride_o_acc
,
split_stride_lse_acc
,
split_stride_lse_acc
,
split_stride_o_acc
},
// args for common karg
split_stride_o_acc
},
// args for common karg
{},
// placeholder for lse
{},
// placeholder for lse
...
@@ -243,12 +234,12 @@ struct FmhaFwdSplitKVCombineKernel
...
@@ -243,12 +234,12 @@ struct FmhaFwdSplitKVCombineKernel
return
kargs
;
return
kargs
;
}
}
__host__
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
_
,
__host__
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
,
ck_tile
::
index_t
nhead
_
,
ck_tile
::
index_t
nhead
,
ck_tile
::
index_t
seqlen_q
_
,
ck_tile
::
index_t
max_
seqlen_q
,
ck_tile
::
index_t
hdim_v
_
)
ck_tile
::
index_t
hdim_v
)
{
{
return
TilePartitioner
::
GridSize
(
batch_size
_
,
nhead
_
,
seqlen_q
_
,
hdim_v
_
);
return
TilePartitioner
::
GridSize
(
batch_size
,
nhead
,
max_
seqlen_q
,
hdim_v
);
}
}
__host__
static
constexpr
auto
BlockSize
()
{
return
dim3
(
kBlockSize
);
}
__host__
static
constexpr
auto
BlockSize
()
{
return
dim3
(
kBlockSize
);
}
...
@@ -270,10 +261,8 @@ struct FmhaFwdSplitKVCombineKernel
...
@@ -270,10 +261,8 @@ struct FmhaFwdSplitKVCombineKernel
const
index_t
i_m0
=
__builtin_amdgcn_readfirstlane
(
i_tile_m
*
FmhaPipeline
::
kM0
);
const
index_t
i_m0
=
__builtin_amdgcn_readfirstlane
(
i_tile_m
*
FmhaPipeline
::
kM0
);
const
index_t
i_n1
=
__builtin_amdgcn_readfirstlane
(
i_tile_n
*
FmhaPipeline
::
kN1
);
const
index_t
i_n1
=
__builtin_amdgcn_readfirstlane
(
i_tile_n
*
FmhaPipeline
::
kN1
);
const
long_index_t
batch_offset_o_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_o_acc
;
long_index_t
batch_offset_lse_acc
=
0
;
long_index_t
batch_offset_lse_acc
=
0
;
long_index_t
batch_offset_o_acc
=
0
;
long_index_t
batch_offset_lse
=
0
;
long_index_t
batch_offset_lse
=
0
;
long_index_t
batch_offset_o
=
0
;
long_index_t
batch_offset_o
=
0
;
...
@@ -282,14 +271,16 @@ struct FmhaFwdSplitKVCombineKernel
...
@@ -282,14 +271,16 @@ struct FmhaFwdSplitKVCombineKernel
// get starting offset for each batch
// get starting offset for each batch
const
long_index_t
query_start
=
kargs
.
seqstart_q_ptr
[
i_batch
];
const
long_index_t
query_start
=
kargs
.
seqstart_q_ptr
[
i_batch
];
batch_offset_o
=
query_start
*
kargs
.
row_stride_o
;
batch_offset_lse_acc
=
query_start
;
batch_offset_lse_acc
=
query_start
;
batch_offset_o_acc
=
query_start
*
kargs
.
row_stride_o_acc
;
if
constexpr
(
kStoreLSE
)
if
constexpr
(
kStoreLSE
)
{
{
batch_offset_lse
=
query_start
;
batch_offset_lse
=
query_start
;
}
}
batch_offset_o
=
query_start
*
kargs
.
row_stride_o
;
// get real # queries & # keys under group mode
// get real # queries & # keys under group mode
const
auto
adjusted_seqstart_q_ptr
=
kargs
.
seqstart_q_ptr
+
i_batch
;
const
auto
adjusted_seqstart_q_ptr
=
kargs
.
seqstart_q_ptr
+
i_batch
;
kargs
.
seqlen_q
=
adjusted_seqstart_q_ptr
[
1
]
-
adjusted_seqstart_q_ptr
[
0
];
kargs
.
seqlen_q
=
adjusted_seqstart_q_ptr
[
1
]
-
adjusted_seqstart_q_ptr
[
0
];
...
@@ -303,13 +294,15 @@ struct FmhaFwdSplitKVCombineKernel
...
@@ -303,13 +294,15 @@ struct FmhaFwdSplitKVCombineKernel
}
}
else
else
{
{
batch_offset_o
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_o
;
batch_offset_lse_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lse_acc
;
batch_offset_lse_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lse_acc
;
batch_offset_o_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_o_acc
;
if
constexpr
(
kStoreLSE
)
if
constexpr
(
kStoreLSE
)
{
{
batch_offset_lse
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lse
;
batch_offset_lse
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lse
;
}
}
batch_offset_o
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_o
;
}
}
// for simplicity, batch stride we just modify the pointer
// for simplicity, batch stride we just modify the pointer
...
@@ -341,7 +334,7 @@ struct FmhaFwdSplitKVCombineKernel
...
@@ -341,7 +334,7 @@ struct FmhaFwdSplitKVCombineKernel
auto
o_acc_dram
=
[
&
]()
{
auto
o_acc_dram
=
[
&
]()
{
const
auto
o_acc_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
const
auto
o_acc_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
o_acc_ptr
,
o_acc_ptr
,
make_tuple
(
kargs
.
num_splits
,
kargs
.
max_
seqlen_q
,
kargs
.
hdim_v
),
make_tuple
(
kargs
.
num_splits
,
kargs
.
seqlen_q
,
kargs
.
hdim_v
),
make_tuple
(
kargs
.
split_stride_o_acc
,
kargs
.
row_stride_o_acc
,
1
),
make_tuple
(
kargs
.
split_stride_o_acc
,
kargs
.
row_stride_o_acc
,
1
),
number
<
FmhaPipeline
::
kAlignmentOacc
>
{},
number
<
FmhaPipeline
::
kAlignmentOacc
>
{},
number
<
1
>
{});
number
<
1
>
{});
...
@@ -351,14 +344,14 @@ struct FmhaFwdSplitKVCombineKernel
...
@@ -351,14 +344,14 @@ struct FmhaFwdSplitKVCombineKernel
make_tuple
(
number
<
1
>
{},
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kN1
>
{}),
make_tuple
(
number
<
1
>
{},
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kN1
>
{}),
sequence
<
false
,
kPadSeqLenQ
,
kPadHeadDimV
>
{});
sequence
<
false
,
kPadSeqLenQ
,
kPadHeadDimV
>
{});
const
index_t
padded_
max_
seqlen_q
=
const
index_t
padded_seqlen_q
=
o_acc_dram_view
.
get_tensor_descriptor
().
get_lengths
()[
number
<
1
>
{}];
o_acc_dram_view
.
get_tensor_descriptor
().
get_lengths
()[
number
<
1
>
{}];
const
index_t
padded_hdim_v
=
const
index_t
padded_hdim_v
=
o_acc_dram_view
.
get_tensor_descriptor
().
get_lengths
()[
number
<
2
>
{}];
o_acc_dram_view
.
get_tensor_descriptor
().
get_lengths
()[
number
<
2
>
{}];
return
transform_tensor_view
(
return
transform_tensor_view
(
o_acc_dram_view
,
o_acc_dram_view
,
make_tuple
(
make_merge_transform
(
make_tuple
(
kargs
.
num_splits
,
padded_
max_
seqlen_q
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
kargs
.
num_splits
,
padded_seqlen_q
)),
make_pass_through_transform
(
padded_hdim_v
)),
make_pass_through_transform
(
padded_hdim_v
)),
make_tuple
(
sequence
<
0
,
1
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
0
,
1
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
...
@@ -417,7 +410,7 @@ struct FmhaFwdSplitKVCombineKernel
...
@@ -417,7 +410,7 @@ struct FmhaFwdSplitKVCombineKernel
identity
{},
// lse_element_func
identity
{},
// lse_element_func
composes
(
saturates
<
fp8_t
>
{},
scales
{
kargs
.
scale_o
}),
// o_acc_element_func
composes
(
saturates
<
fp8_t
>
{},
scales
{
kargs
.
scale_o
}),
// o_acc_element_func
kargs
.
num_splits
,
kargs
.
num_splits
,
kargs
.
max_
seqlen_q
,
kargs
.
seqlen_q
,
smem_ptr
);
smem_ptr
);
}
}
else
else
...
@@ -426,7 +419,7 @@ struct FmhaFwdSplitKVCombineKernel
...
@@ -426,7 +419,7 @@ struct FmhaFwdSplitKVCombineKernel
o_acc_dram_window
,
o_acc_dram_window
,
lse_dram_window
,
lse_dram_window
,
kargs
.
num_splits
,
kargs
.
num_splits
,
kargs
.
max_
seqlen_q
,
kargs
.
seqlen_q
,
smem_ptr
);
smem_ptr
);
}
}
}();
}();
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp
View file @
09d4c3a4
...
@@ -13,21 +13,20 @@ struct FmhaFwdSplitKVCombineTilePartitioner
...
@@ -13,21 +13,20 @@ struct FmhaFwdSplitKVCombineTilePartitioner
static
constexpr
ck_tile
::
index_t
kM0
=
kM0_
;
static
constexpr
ck_tile
::
index_t
kM0
=
kM0_
;
static
constexpr
ck_tile
::
index_t
kN1
=
kN1_
;
static
constexpr
ck_tile
::
index_t
kN1
=
kN1_
;
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
_
,
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
,
ck_tile
::
index_t
nhead
_
,
ck_tile
::
index_t
nhead
,
ck_tile
::
index_t
seqlen_q
_
,
ck_tile
::
index_t
max_
seqlen_q
,
ck_tile
::
index_t
hdim_v
_
)
ck_tile
::
index_t
hdim_v
)
{
{
// TODO: this may need tuning
// TODO: this may need tuning
return
dim3
(
ck_tile
::
integer_divide_ceil
(
seqlen_q
_
,
kM0
)
*
return
dim3
(
ck_tile
::
integer_divide_ceil
(
max_
seqlen_q
,
kM0
)
*
ck_tile
::
integer_divide_ceil
(
hdim_v
_
,
kN1
),
ck_tile
::
integer_divide_ceil
(
hdim_v
,
kN1
),
nhead
_
,
nhead
,
batch_size
_
);
batch_size
);
}
}
CK_TILE_DEVICE
auto
operator
()(
ck_tile
::
index_t
/*seqlen_q*/
,
ck_tile
::
index_t
hdim_v
)
CK_TILE_DEVICE
auto
operator
()(
ck_tile
::
index_t
/*seqlen_q*/
,
ck_tile
::
index_t
hdim_v
)
{
{
// const index_t num_tile_m0 = seqlen_q / kM0;
const
index_t
num_tile_n1
=
ck_tile
::
integer_divide_ceil
(
hdim_v
,
kN1
);
const
index_t
num_tile_n1
=
ck_tile
::
integer_divide_ceil
(
hdim_v
,
kN1
);
const
index_t
i_block
=
blockIdx
.
x
;
const
index_t
i_block
=
blockIdx
.
x
;
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
View file @
09d4c3a4
...
@@ -135,9 +135,6 @@ struct FmhaFwdSplitKVKernel
...
@@ -135,9 +135,6 @@ struct FmhaFwdSplitKVKernel
ck_tile
::
index_t
nhead_stride_lse_acc
;
ck_tile
::
index_t
nhead_stride_lse_acc
;
ck_tile
::
index_t
nhead_stride_o_acc
;
ck_tile
::
index_t
nhead_stride_o_acc
;
ck_tile
::
index_t
batch_stride_lse_acc
;
ck_tile
::
index_t
batch_stride_o_acc
;
ck_tile
::
index_t
split_stride_lse_acc
;
ck_tile
::
index_t
split_stride_lse_acc
;
ck_tile
::
index_t
split_stride_o_acc
;
ck_tile
::
index_t
split_stride_o_acc
;
};
};
...
@@ -201,6 +198,8 @@ struct FmhaFwdSplitKVKernel
...
@@ -201,6 +198,8 @@ struct FmhaFwdSplitKVKernel
ck_tile
::
index_t
batch_stride_q
;
ck_tile
::
index_t
batch_stride_q
;
ck_tile
::
index_t
batch_stride_k
;
ck_tile
::
index_t
batch_stride_k
;
ck_tile
::
index_t
batch_stride_v
;
ck_tile
::
index_t
batch_stride_v
;
ck_tile
::
index_t
batch_stride_lse_acc
;
ck_tile
::
index_t
batch_stride_o_acc
;
};
};
struct
GroupModeKargs
struct
GroupModeKargs
...
@@ -217,8 +216,8 @@ struct FmhaFwdSplitKVKernel
...
@@ -217,8 +216,8 @@ struct FmhaFwdSplitKVKernel
const
int32_t
*
seqstart_k_ptr
;
const
int32_t
*
seqstart_k_ptr
;
const
int32_t
*
seqlen_k_ptr
;
const
int32_t
*
seqlen_k_ptr
;
ck_tile
::
index_t
batch_stride_k
;
ck_tile
::
index_t
batch_stride_k
;
// only used for paged-kvcache
ck_tile
::
index_t
batch_stride_v
;
ck_tile
::
index_t
batch_stride_v
;
// only used for paged-kvcache
};
};
using
Kargs
=
std
::
conditional_t
<
kIsGroupMode
,
GroupModeKargs
,
BatchModeKargs
>
;
using
Kargs
=
std
::
conditional_t
<
kIsGroupMode
,
GroupModeKargs
,
BatchModeKargs
>
;
...
@@ -296,8 +295,6 @@ struct FmhaFwdSplitKVKernel
...
@@ -296,8 +295,6 @@ struct FmhaFwdSplitKVKernel
nhead_stride_v
,
nhead_stride_v
,
nhead_stride_lse_acc
,
nhead_stride_lse_acc
,
nhead_stride_o_acc
,
nhead_stride_o_acc
,
batch_stride_lse_acc
,
batch_stride_o_acc
,
split_stride_lse_acc
,
split_stride_lse_acc
,
split_stride_o_acc
},
// args for common karg
split_stride_o_acc
},
// args for common karg
{},
// placeholder for bias
{},
// placeholder for bias
...
@@ -307,7 +304,9 @@ struct FmhaFwdSplitKVKernel
...
@@ -307,7 +304,9 @@ struct FmhaFwdSplitKVKernel
reinterpret_cast
<
const
int32_t
*>
(
seqlen_k_ptr
),
reinterpret_cast
<
const
int32_t
*>
(
seqlen_k_ptr
),
batch_stride_q
,
batch_stride_q
,
batch_stride_k
,
batch_stride_k
,
batch_stride_v
};
batch_stride_v
,
batch_stride_lse_acc
,
batch_stride_o_acc
};
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
{
...
@@ -375,10 +374,8 @@ struct FmhaFwdSplitKVKernel
...
@@ -375,10 +374,8 @@ struct FmhaFwdSplitKVKernel
ck_tile
::
index_t
nhead_stride_bias
,
ck_tile
::
index_t
nhead_stride_bias
,
ck_tile
::
index_t
nhead_stride_lse_acc
,
ck_tile
::
index_t
nhead_stride_lse_acc
,
ck_tile
::
index_t
nhead_stride_o_acc
,
ck_tile
::
index_t
nhead_stride_o_acc
,
ck_tile
::
index_t
batch_stride_k
,
ck_tile
::
index_t
batch_stride_k
,
// only used for paged-kvcache
ck_tile
::
index_t
batch_stride_v
,
ck_tile
::
index_t
batch_stride_v
,
// only used for paged-kvcache
ck_tile
::
index_t
batch_stride_lse_acc
,
ck_tile
::
index_t
batch_stride_o_acc
,
ck_tile
::
index_t
split_stride_lse_acc
,
ck_tile
::
index_t
split_stride_lse_acc
,
ck_tile
::
index_t
split_stride_o_acc
,
ck_tile
::
index_t
split_stride_o_acc
,
ck_tile
::
index_t
window_size_left
,
ck_tile
::
index_t
window_size_left
,
...
@@ -412,8 +409,6 @@ struct FmhaFwdSplitKVKernel
...
@@ -412,8 +409,6 @@ struct FmhaFwdSplitKVKernel
nhead_stride_v
,
nhead_stride_v
,
nhead_stride_lse_acc
,
nhead_stride_lse_acc
,
nhead_stride_o_acc
,
nhead_stride_o_acc
,
batch_stride_lse_acc
,
batch_stride_o_acc
,
split_stride_lse_acc
,
split_stride_lse_acc
,
split_stride_o_acc
},
// args for common karg
split_stride_o_acc
},
// args for common karg
{},
// placeholder for bias
{},
// placeholder for bias
...
@@ -452,11 +447,11 @@ struct FmhaFwdSplitKVKernel
...
@@ -452,11 +447,11 @@ struct FmhaFwdSplitKVKernel
__host__
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
,
__host__
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
,
ck_tile
::
index_t
nhead
,
ck_tile
::
index_t
nhead
,
ck_tile
::
index_t
seqlen_q
,
ck_tile
::
index_t
max_
seqlen_q
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
num_splits
)
ck_tile
::
index_t
num_splits
)
{
{
return
TilePartitioner
::
GridSize
(
batch_size
,
nhead
,
seqlen_q
,
hdim_v
,
num_splits
);
return
TilePartitioner
::
GridSize
(
batch_size
,
nhead
,
max_
seqlen_q
,
hdim_v
,
num_splits
);
}
}
__host__
static
constexpr
auto
BlockSize
()
{
return
dim3
(
kBlockSize
);
}
__host__
static
constexpr
auto
BlockSize
()
{
return
dim3
(
kBlockSize
);
}
...
@@ -483,8 +478,7 @@ struct FmhaFwdSplitKVKernel
...
@@ -483,8 +478,7 @@ struct FmhaFwdSplitKVKernel
long_index_t
batch_offset_v
=
0
;
long_index_t
batch_offset_v
=
0
;
long_index_t
batch_offset_bias
=
0
;
long_index_t
batch_offset_bias
=
0
;
long_index_t
batch_offset_lse_acc
=
0
;
long_index_t
batch_offset_lse_acc
=
0
;
const
long_index_t
batch_offset_o_acc
=
long_index_t
batch_offset_o_acc
=
0
;
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_o_acc
;
if
constexpr
(
kIsGroupMode
)
if
constexpr
(
kIsGroupMode
)
{
{
...
@@ -492,9 +486,9 @@ struct FmhaFwdSplitKVKernel
...
@@ -492,9 +486,9 @@ struct FmhaFwdSplitKVKernel
const
long_index_t
query_start
=
kargs
.
seqstart_q_ptr
[
i_batch
];
const
long_index_t
query_start
=
kargs
.
seqstart_q_ptr
[
i_batch
];
const
long_index_t
key_start
=
kargs
.
seqstart_k_ptr
[
i_batch
];
const
long_index_t
key_start
=
kargs
.
seqstart_k_ptr
[
i_batch
];
batch_offset_q
=
query_start
*
kargs
.
stride_q
;
batch_offset_q
=
query_start
*
kargs
.
stride_q
;
batch_offset_k
=
key_start
*
kargs
.
stride_k
;
batch_offset_k
=
key_start
*
kargs
.
stride_k
;
batch_offset_lse_acc
=
query_start
;
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
{
batch_offset_v
=
key_start
*
kargs
.
stride_v
;
batch_offset_v
=
key_start
*
kargs
.
stride_v
;
...
@@ -508,6 +502,9 @@ struct FmhaFwdSplitKVKernel
...
@@ -508,6 +502,9 @@ struct FmhaFwdSplitKVKernel
batch_offset_bias
=
query_start
*
kargs
.
stride_bias
+
key_start
;
batch_offset_bias
=
query_start
*
kargs
.
stride_bias
+
key_start
;
}
}
batch_offset_lse_acc
=
query_start
;
batch_offset_o_acc
=
query_start
*
kargs
.
stride_o_acc
;
// get real # queries & # keys under group mode
// get real # queries & # keys under group mode
kargs
.
seqlen_q
=
kargs
.
seqstart_q_ptr
[
i_batch
+
1
]
-
kargs
.
seqstart_q_ptr
[
i_batch
];
kargs
.
seqlen_q
=
kargs
.
seqstart_q_ptr
[
i_batch
+
1
]
-
kargs
.
seqstart_q_ptr
[
i_batch
];
...
@@ -545,6 +542,7 @@ struct FmhaFwdSplitKVKernel
...
@@ -545,6 +542,7 @@ struct FmhaFwdSplitKVKernel
batch_offset_k
=
static_cast
<
long_index_t
>
(
i_cache_batch
)
*
kargs
.
batch_stride_k
;
batch_offset_k
=
static_cast
<
long_index_t
>
(
i_cache_batch
)
*
kargs
.
batch_stride_k
;
batch_offset_v
=
static_cast
<
long_index_t
>
(
i_cache_batch
)
*
kargs
.
batch_stride_v
;
batch_offset_v
=
static_cast
<
long_index_t
>
(
i_cache_batch
)
*
kargs
.
batch_stride_v
;
batch_offset_lse_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lse_acc
;
batch_offset_lse_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lse_acc
;
batch_offset_o_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_o_acc
;
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
{
...
@@ -895,8 +893,8 @@ struct FmhaFwdSplitKVKernel
...
@@ -895,8 +893,8 @@ struct FmhaFwdSplitKVKernel
const
auto
o_acc_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
const
auto
o_acc_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
o_acc_ptr
,
o_acc_ptr
,
make_tuple
(
kargs
.
seqlen_q
,
kargs
.
hdim_v
),
make_tuple
(
kargs
.
seqlen_q
,
kargs
.
hdim_v
),
make_tuple
(
kargs
.
hdim_v
,
1
),
make_tuple
(
kargs
.
stride_o_acc
,
1
),
number
<
FmhaPipeline
::
kAlignmentO
>
{},
number
<
1
>
{},
number
<
1
>
{});
number
<
1
>
{});
return
pad_tensor_view
(
return
pad_tensor_view
(
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp
View file @
09d4c3a4
...
@@ -20,12 +20,12 @@ struct FmhaFwdSplitKVTilePartitioner
...
@@ -20,12 +20,12 @@ struct FmhaFwdSplitKVTilePartitioner
__host__
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
,
__host__
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
,
ck_tile
::
index_t
nhead
,
ck_tile
::
index_t
nhead
,
ck_tile
::
index_t
seqlen_q
,
ck_tile
::
index_t
max_
seqlen_q
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
num_splits
)
ck_tile
::
index_t
num_splits
)
{
{
// TODO: this may need tuning
// TODO: this may need tuning
return
dim3
(
ck_tile
::
integer_divide_ceil
(
seqlen_q
,
kM0
)
*
return
dim3
(
ck_tile
::
integer_divide_ceil
(
max_
seqlen_q
,
kM0
)
*
ck_tile
::
integer_divide_ceil
(
hdim_v
,
kN1
),
ck_tile
::
integer_divide_ceil
(
hdim_v
,
kN1
),
nhead
*
num_splits
,
nhead
*
num_splits
,
batch_size
);
batch_size
);
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp
View file @
09d4c3a4
...
@@ -827,6 +827,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
...
@@ -827,6 +827,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
},
},
s_acc
,
s_acc
,
bias_s_tile
);
bias_s_tile
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
else
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
else
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
{
...
@@ -918,6 +919,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
...
@@ -918,6 +919,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
gemm_1
(
dv_acc
,
pt_reg_tensor
,
dot_reg_tensor
);
gemm_1
(
dv_acc
,
pt_reg_tensor
,
dot_reg_tensor
);
HotLoopScheduler
::
template
GemmStagedScheduler
<
1
>();
HotLoopScheduler
::
template
GemmStagedScheduler
<
1
>();
__builtin_amdgcn_sched_barrier
(
0
);
// STAGE 4, OGrad@V Gemm2
// STAGE 4, OGrad@V Gemm2
auto
dp_acc
=
SPGradBlockTileType
{};
auto
dp_acc
=
SPGradBlockTileType
{};
...
@@ -927,6 +929,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
...
@@ -927,6 +929,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
dp_acc
=
gemm_2
(
do_reg_tensor
,
v_reg_tensor
);
dp_acc
=
gemm_2
(
do_reg_tensor
,
v_reg_tensor
);
HotLoopScheduler
::
template
GemmStagedScheduler
<
2
>();
HotLoopScheduler
::
template
GemmStagedScheduler
<
2
>();
__builtin_amdgcn_sched_barrier
(
0
);
// STAGE 5, P^T(PGrad^T - D)
// STAGE 5, P^T(PGrad^T - D)
auto
ds
=
SPGradBlockTileType
{};
auto
ds
=
SPGradBlockTileType
{};
...
@@ -965,6 +968,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
...
@@ -965,6 +968,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
Policy
::
template
MakeBiasTileDistribution
<
Problem
>());
Policy
::
template
MakeBiasTileDistribution
<
Problem
>());
shuffle_tile
(
dbias_tile
,
shuffled_dbias_tile
);
shuffle_tile
(
dbias_tile
,
shuffled_dbias_tile
);
store_tile
(
dbias_dram_window
,
dbias_tile
);
store_tile
(
dbias_dram_window
,
dbias_tile
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
// STAGE 6, SGrad^T@Q^T Gemm3
// STAGE 6, SGrad^T@Q^T Gemm3
...
@@ -984,6 +988,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
...
@@ -984,6 +988,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
move_tile_window
(
ds_lds_read_window
,
{
0
,
kK4
});
move_tile_window
(
ds_lds_read_window
,
{
0
,
kK4
});
HotLoopScheduler
::
template
GemmStagedScheduler
<
3
>();
HotLoopScheduler
::
template
GemmStagedScheduler
<
3
>();
__builtin_amdgcn_sched_barrier
(
0
);
// STAGE 7, SGrad@K^T Gemm4
// STAGE 7, SGrad@K^T Gemm4
auto
dq_acc
=
QGradBlockTileType
{};
auto
dq_acc
=
QGradBlockTileType
{};
clear_tile
(
dq_acc
);
clear_tile
(
dq_acc
);
...
@@ -1005,6 +1010,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
...
@@ -1005,6 +1010,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
});
});
HotLoopScheduler
::
template
GemmStagedScheduler
<
4
>();
HotLoopScheduler
::
template
GemmStagedScheduler
<
4
>();
__builtin_amdgcn_sched_barrier
(
0
);
// Results Scale
// Results Scale
if
constexpr
(
FmhaDropout
::
IsDropout
)
if
constexpr
(
FmhaDropout
::
IsDropout
)
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
View file @
09d4c3a4
...
@@ -25,14 +25,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -25,14 +25,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
{
{
using
BlockGemmProblem
=
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
BlockGemmPipelineProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
TileGemmShape
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
Problem
::
BlockFmhaShape
::
kK0
>>
;
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>>
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
QDataType
,
typename
Problem
::
QDataType
,
...
@@ -57,14 +58,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -57,14 +58,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetPTOGradTBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetPTOGradTBlockGemm
()
{
{
using
BlockGemmProblem
=
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
BlockGemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
GemmDataType
,
typename
Problem
::
OGradDataType
,
typename
Problem
::
OGradDataType
,
typename
Problem
::
AccDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kN0
,
TileGemmShape
<
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kVHeaddim
,
Problem
::
BlockFmhaShape
::
kVHeaddim
,
Problem
::
BlockFmhaShape
::
kK1
>
,
Problem
::
BlockFmhaShape
::
kK1
>>
;
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
>>
;
using
WarpGemm
=
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
...
@@ -88,14 +90,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -88,14 +90,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetOGradVBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetOGradVBlockGemm
()
{
{
using
BlockGemmProblem
=
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
BlockGemmPipelineProblem
<
typename
Problem
::
OGradDataType
,
typename
Problem
::
OGradDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
AccDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
TileGemmShape
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK2
>
,
Problem
::
BlockFmhaShape
::
kK2
>>
;
typename
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm2WarpTile
>>
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
OGradDataType
,
typename
Problem
::
OGradDataType
,
...
@@ -120,14 +123,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -120,14 +123,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSGradTQTBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSGradTQTBlockGemm
()
{
{
using
BlockGemmProblem
=
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
BlockGemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
GemmDataType
,
typename
Problem
::
QDataType
,
typename
Problem
::
QDataType
,
typename
Problem
::
AccDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kN0
,
TileGemmShape
<
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kK3
>
,
Problem
::
BlockFmhaShape
::
kK3
>>
;
typename
Problem
::
BlockFmhaShape
::
Gemm3BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm3WarpTile
>>
;
using
WarpGemm
=
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
...
@@ -151,14 +155,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -151,14 +155,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSGradKTBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSGradKTBlockGemm
()
{
{
using
BlockGemmProblem
=
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
BlockGemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
GemmDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
TileGemmShape
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kK4
>
,
Problem
::
BlockFmhaShape
::
kK4
>>
;
typename
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
>>
;
using
WarpGemm
=
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
...
@@ -1722,7 +1727,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -1722,7 +1727,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
}
template
<
>
template
<
>
CK_TILE_DEVICE
static
constexpr
void
GemmStagedScheduler
<
0
>
()
CK_TILE_DEVICE
constexpr
void
GemmStagedScheduler
<
0
>
()
{
{
// Mem: Q, LSE, OGrad, D global load, OGrad^T LDS load
// Mem: Q, LSE, OGrad, D global load, OGrad^T LDS load
// Comp: Q x K
// Comp: Q x K
...
@@ -1754,7 +1759,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -1754,7 +1759,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
}
template
<
>
template
<
>
CK_TILE_DEVICE
static
constexpr
void
GemmStagedScheduler
<
1
>
()
CK_TILE_DEVICE
constexpr
void
GemmStagedScheduler
<
1
>
()
{
{
// Mem: Q^T LDS load
// Mem: Q^T LDS load
// Comp: OGrad x V
// Comp: OGrad x V
...
@@ -1772,7 +1777,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -1772,7 +1777,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
}
template
<
>
template
<
>
CK_TILE_DEVICE
static
constexpr
void
GemmStagedScheduler
<
2
>
()
CK_TILE_DEVICE
constexpr
void
GemmStagedScheduler
<
2
>
()
{
{
// Mem: Q, QT, LSE, OGrad, OGradT, D, LDS store
// Mem: Q, QT, LSE, OGrad, OGradT, D, LDS store
// Comp: PT x OGrad
// Comp: PT x OGrad
...
@@ -1791,7 +1796,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -1791,7 +1796,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
}
template
<
>
template
<
>
CK_TILE_DEVICE
static
constexpr
void
GemmStagedScheduler
<
3
>
()
CK_TILE_DEVICE
constexpr
void
GemmStagedScheduler
<
3
>
()
{
{
// Mem: SGradT LDS store, SGrad, Q, LSE LDS load.
// Mem: SGradT LDS store, SGrad, Q, LSE LDS load.
// Comp: SGradT x QT
// Comp: SGradT x QT
...
@@ -1825,7 +1830,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -1825,7 +1830,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
}
template
<
>
template
<
>
CK_TILE_DEVICE
static
constexpr
void
GemmStagedScheduler
<
4
>
()
CK_TILE_DEVICE
constexpr
void
GemmStagedScheduler
<
4
>
()
{
{
// Mem: SGrad, OGrad, D LDS load.
// Mem: SGrad, OGrad, D LDS load.
// Comp: SGrad x KT
// Comp: SGrad x KT
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp
View file @
09d4c3a4
...
@@ -107,7 +107,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
...
@@ -107,7 +107,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
const
LSEElementFunction
&
lse_element_func
,
const
LSEElementFunction
&
lse_element_func
,
const
OaccElementFunction
&
o_acc_element_func
,
const
OaccElementFunction
&
o_acc_element_func
,
index_t
num_splits
,
index_t
num_splits
,
index_t
max_
seqlen_q
,
index_t
seqlen_q
,
void
*
smem_ptr
)
const
void
*
smem_ptr
)
const
{
{
// lse_acc tile in LDS
// lse_acc tile in LDS
...
@@ -261,7 +261,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
...
@@ -261,7 +261,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
auto
o_acc
=
make_static_distributed_tensor
<
OaccDataType
>
(
o_acc_dist
);
auto
o_acc
=
make_static_distributed_tensor
<
OaccDataType
>
(
o_acc_dist
);
clear_tile
(
o_acc
);
clear_tile
(
o_acc
);
const
index_t
padded_
max_
seqlen_q
=
integer_divide_ceil
(
max_
seqlen_q
,
kM0
)
*
kM0
;
const
index_t
padded_seqlen_q
=
integer_divide_ceil
(
seqlen_q
,
kM0
)
*
kM0
;
for
(
index_t
i_split
=
0
;
i_split
<
num_splits
;
++
i_split
)
for
(
index_t
i_split
=
0
;
i_split
<
num_splits
;
++
i_split
)
{
{
...
@@ -282,7 +282,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
...
@@ -282,7 +282,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
});
});
}
}
move_tile_window
(
o_acc_dram_window
,
{
padded_
max_
seqlen_q
,
0
});
move_tile_window
(
o_acc_dram_window
,
{
padded_seqlen_q
,
0
});
}
}
o_acc
=
tile_elementwise_in
(
o_acc_element_func
,
o_acc
);
o_acc
=
tile_elementwise_in
(
o_acc_element_func
,
o_acc
);
...
@@ -297,7 +297,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
...
@@ -297,7 +297,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
const
OaccDramBlockWindow
&
o_acc_dram_block_window
,
const
OaccDramBlockWindow
&
o_acc_dram_block_window
,
LSEDramBlockWindow
&
lse_dram_block_window
,
LSEDramBlockWindow
&
lse_dram_block_window
,
index_t
num_splits
,
index_t
num_splits
,
index_t
max_
seqlen_q
,
index_t
seqlen_q
,
void
*
smem_ptr
)
const
void
*
smem_ptr
)
const
{
{
return
operator
()(
lse_acc_dram_block_window
,
return
operator
()(
lse_acc_dram_block_window
,
...
@@ -306,7 +306,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
...
@@ -306,7 +306,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
identity
{},
identity
{},
identity
{},
identity
{},
num_splits
,
num_splits
,
max_
seqlen_q
,
seqlen_q
,
smem_ptr
);
smem_ptr
);
}
}
};
};
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
View file @
09d4c3a4
...
@@ -64,8 +64,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
...
@@ -64,8 +64,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
return
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
return
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
}();
}();
static
constexpr
index_t
kAlignmentO
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentO
<
Problem
>();
static
constexpr
index_t
kAlignmentBias
=
static
constexpr
index_t
kAlignmentBias
=
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentBias
<
Problem
>();
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentBias
<
Problem
>();
...
@@ -212,8 +210,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
...
@@ -212,8 +210,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
const
auto
[
seqlen_k_start
,
seqlen_k_end
]
=
mask
.
GetTileRangeAlongX
(
const
auto
[
seqlen_k_start
,
seqlen_k_end
]
=
mask
.
GetTileRangeAlongX
(
q_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{},
num_splits
,
i_split
);
q_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{},
num_splits
,
i_split
);
// check early exit if
masked and
no work to do
.
// check early exit if no work to do
if
constexpr
(
FmhaMask
::
IsMasking
||
kHasUnevenSplits
)
if
constexpr
(
FmhaMask
::
IsMasking
||
kPadSeqLenK
||
kHasUnevenSplits
)
{
{
const
index_t
original_num_total_loop
=
const
index_t
original_num_total_loop
=
integer_divide_ceil
(
seqlen_k_end
-
seqlen_k_start
,
kN0
);
integer_divide_ceil
(
seqlen_k_end
-
seqlen_k_start
,
kN0
);
...
@@ -616,7 +614,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
...
@@ -616,7 +614,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
sweep_tile_span
(
o_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
o_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
const
auto
tmp
=
[
&
]()
{
const
auto
tmp
=
[
&
]()
{
if
constexpr
(
FmhaMask
::
IsMasking
)
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
FmhaMask
::
IsMasking
)
{
{
return
l
[
i_idx
]
==
0.
f
?
0.
f
:
1
/
l
[
i_idx
];
return
l
[
i_idx
]
==
0.
f
?
0.
f
:
1
/
l
[
i_idx
];
}
}
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
View file @
09d4c3a4
...
@@ -215,8 +215,8 @@ struct BlockFmhaPipelineQRKSVS
...
@@ -215,8 +215,8 @@ struct BlockFmhaPipelineQRKSVS
const
auto
num_total_loop
=
integer_divide_ceil
(
seqlen_k_end
-
seqlen_k_start
,
kN0
);
const
auto
num_total_loop
=
integer_divide_ceil
(
seqlen_k_end
-
seqlen_k_start
,
kN0
);
// check early exit if
masked and
no work to do
.
// check early exit if no work to do
if
constexpr
(
FmhaMask
::
IsMasking
)
if
constexpr
(
FmhaMask
::
IsMasking
||
kPadSeqLenK
)
{
{
if
(
num_total_loop
<=
0
)
if
(
num_total_loop
<=
0
)
{
{
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
View file @
09d4c3a4
...
@@ -268,7 +268,7 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -268,7 +268,7 @@ struct BlockFmhaPipelineQRKSVSAsync
const
auto
num_total_loop
=
integer_divide_ceil
(
seqlen_k_end
-
seqlen_k_start
,
kN0
);
const
auto
num_total_loop
=
integer_divide_ceil
(
seqlen_k_end
-
seqlen_k_start
,
kN0
);
// check early exit
// check early exit
if no work to do
if
constexpr
(
FmhaMask
::
IsMasking
||
kPadSeqLenK
)
if
constexpr
(
FmhaMask
::
IsMasking
||
kPadSeqLenK
)
{
{
if
(
num_total_loop
<=
0
)
if
(
num_total_loop
<=
0
)
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
View file @
09d4c3a4
...
@@ -75,14 +75,15 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
...
@@ -75,14 +75,15 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
{
{
using
BlockGemmProblem
=
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
BlockGemmPipelineProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
SaccDataType
,
typename
Problem
::
SaccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
TileGemmShape
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
Problem
::
BlockFmhaShape
::
kK0
>>
;
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>>
;
constexpr
auto
warp_gemm
=
[]()
{
constexpr
auto
warp_gemm
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
half_t
>
&&
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
half_t
>
&&
...
@@ -198,14 +199,15 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
...
@@ -198,14 +199,15 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
{
{
using
BlockGemmProblem
=
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
BlockGemmPipelineProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
SaccDataType
,
typename
Problem
::
SaccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
TileGemmShape
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
Problem
::
BlockFmhaShape
::
kK0
>>
;
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>>
;
constexpr
auto
warp_gemm
=
[]()
{
constexpr
auto
warp_gemm
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
half_t
>
&&
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
half_t
>
&&
...
@@ -952,14 +954,15 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
...
@@ -952,14 +954,15 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetKVBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetKVBlockGemm
()
{
{
using
BlockGemmProblem
=
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
BlockGemmPipelineProblem
<
typename
Problem
::
PDataType
,
typename
Problem
::
PDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
OaccDataType
,
typename
Problem
::
OaccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
TileGemmShape
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN1
,
Problem
::
BlockFmhaShape
::
kN1
,
Problem
::
BlockFmhaShape
::
kK1
>
,
Problem
::
BlockFmhaShape
::
kK1
>>
;
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
>>
;
auto
warp_gemm
=
[
&
]()
{
auto
warp_gemm
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
KDataType
,
fp8_t
>
&&
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
KDataType
,
fp8_t
>
&&
...
...
include/ck_tile/ops/gemm.hpp
View file @
09d4c3a4
...
@@ -21,6 +21,8 @@
...
@@ -21,6 +21,8 @@
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp"
...
...
include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp
View file @
09d4c3a4
...
@@ -4,7 +4,8 @@
...
@@ -4,7 +4,8 @@
#pragma once
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
...
@@ -27,9 +28,9 @@ struct BlockGemmARegBGmemCRegV1
...
@@ -27,9 +28,9 @@ struct BlockGemmARegBGmemCRegV1
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
// use BlockGemmARegBSmemCRegV1 as the underlying block-GEMM implementation
// use BlockGemmARegBSmemCRegV1 as the underlying block-GEMM implementation
using
BlockGemmARegB
S
memCRegImpl
=
BlockGemmARegB
S
memCRegV1
<
using
BlockGemmARegB
G
memCRegImpl
=
BlockGemmARegB
G
memCRegV1
<
BlockGemmProblem
<
ADataType
,
BDataType
,
CDataType
,
kBlockSize
,
BlockGemmShape
>
,
BlockGemmProblem
<
ADataType
,
BDataType
,
CDataType
,
kBlockSize
,
BlockGemmShape
>
,
BlockGemmARegB
S
memCRegV1DefaultPolicy
>
;
BlockGemmARegB
G
memCRegV1DefaultPolicy
>
;
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetStaticLdsSize
()
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetStaticLdsSize
()
{
{
...
@@ -82,7 +83,7 @@ struct BlockGemmARegBGmemCRegV1
...
@@ -82,7 +83,7 @@ struct BlockGemmARegBGmemCRegV1
block_sync_lds
();
block_sync_lds
();
// block GEMM
// block GEMM
BlockGemmARegB
S
memCRegImpl
{}(
c_block_tensor
,
a_block_tensor
,
b_block_smem_window
);
BlockGemmARegB
G
memCRegImpl
{}(
c_block_tensor
,
a_block_tensor
,
b_block_smem_window
);
}
}
// C = A * B
// C = A * B
...
@@ -128,7 +129,7 @@ struct BlockGemmARegBGmemCRegV1
...
@@ -128,7 +129,7 @@ struct BlockGemmARegBGmemCRegV1
block_sync_lds
();
block_sync_lds
();
// block GEMM
// block GEMM
return
BlockGemmARegB
S
memCRegImpl
{}(
a_block_tensor
,
b_block_smem_window
);
return
BlockGemmARegB
G
memCRegImpl
{}(
a_block_tensor
,
b_block_smem_window
);
}
}
};
};
...
...
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp
View file @
09d4c3a4
...
@@ -49,6 +49,10 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
...
@@ -49,6 +49,10 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
{
{
return
make_tuple
(
WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution
{},
4
,
1
);
return
make_tuple
(
WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution
{},
4
,
1
);
}
}
else
{
static_assert
(
false
,
"Unsupported data type configuration for GEMM warp execution."
);
}
}
}
};
};
...
...
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
0 → 100644
View file @
09d4c3a4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include <iostream>
#include <string>
namespace
ck_tile
{
template
<
typename
TilePartitioner_
,
typename
GemmPipeline_
,
typename
EpiloguePipeline_
,
typename
LayoutA_
,
typename
LayoutB_
,
typename
LayoutC_
>
struct
GemmKernel
{
using
TilePartitioner
=
remove_cvref_t
<
TilePartitioner_
>
;
using
GemmPipeline
=
remove_cvref_t
<
GemmPipeline_
>
;
using
EpiloguePipeline
=
remove_cvref_t
<
EpiloguePipeline_
>
;
using
LayoutA
=
remove_cvref_t
<
LayoutA_
>
;
using
LayoutB
=
remove_cvref_t
<
LayoutB_
>
;
using
LayoutC
=
remove_cvref_t
<
LayoutC_
>
;
static
constexpr
index_t
KernelBlockSize
=
GemmPipeline
::
kBlockSize
;
using
ADataType
=
remove_cvref_t
<
typename
GemmPipeline
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
GemmPipeline
::
BDataType
>
;
using
CAccDataType
=
remove_cvref_t
<
typename
GemmPipeline
::
CDataType
>
;
using
CODataType
=
remove_cvref_t
<
typename
EpiloguePipeline
::
ODataType
>
;
__host__
static
constexpr
auto
GridSize
(
index_t
M_size
,
index_t
N_size
,
index_t
Batch_size
)
{
return
TilePartitioner
::
GridSize
(
M_size
,
N_size
,
Batch_size
);
}
__host__
static
constexpr
auto
BlockSize
()
{
return
dim3
(
KernelBlockSize
);
}
struct
GemmCommonKargs
{
const
void
*
a_ptr
;
const
void
*
b_ptr
;
void
*
c_ptr
;
float
epsilon
;
ck_tile
::
index_t
M
;
ck_tile
::
index_t
N
;
ck_tile
::
index_t
K
;
ck_tile
::
index_t
stride_A
;
ck_tile
::
index_t
stride_B
;
ck_tile
::
index_t
stride_C
;
};
CK_TILE_HOST
static
constexpr
GemmCommonKargs
MakeKargs
(
const
void
*
a_ptr
,
const
void
*
b_ptr
,
void
*
c_ptr
,
float
epsilon
,
ck_tile
::
index_t
M
,
ck_tile
::
index_t
N
,
ck_tile
::
index_t
K
,
ck_tile
::
index_t
stride_A
,
ck_tile
::
index_t
stride_B
,
ck_tile
::
index_t
stride_C
)
{
return
GemmCommonKargs
{
a_ptr
,
b_ptr
,
c_ptr
,
epsilon
,
M
,
N
,
K
,
stride_A
,
stride_B
,
stride_C
};
}
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
return
ck_tile
::
max
(
GemmPipeline
::
GetSmemSize
(),
EpiloguePipeline
::
GetSmemSize
());
}
CK_TILE_DEVICE
void
operator
()(
GemmCommonKargs
kargs
)
const
{
const
auto
[
i_m
,
i_n
]
=
TilePartitioner
{}();
// options
const
ADataType
*
a_start
=
static_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
);
const
BDataType
*
b_start
=
static_cast
<
const
BDataType
*>
(
kargs
.
b_ptr
);
// Convert pointers to tensor views
auto
a_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
LayoutA
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
a_start
,
make_tuple
(
kargs
.
M
,
kargs
.
K
),
make_tuple
(
1
,
kargs
.
stride_A
),
number
<
GemmPipeline
::
AlignmentA
>
{},
number
<
1
>
{});
}
else
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
a_start
,
make_tuple
(
kargs
.
M
,
kargs
.
K
),
make_tuple
(
kargs
.
stride_A
,
1
),
number
<
GemmPipeline
::
AlignmentA
>
{},
number
<
1
>
{});
}
}();
auto
b_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
LayoutB
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
b_start
,
make_tuple
(
kargs
.
N
,
kargs
.
K
),
make_tuple
(
1
,
kargs
.
stride_B
),
number
<
GemmPipeline
::
AlignmentB
>
{},
number
<
1
>
{});
}
else
{
// Default NK layout
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
b_start
,
make_tuple
(
kargs
.
N
,
kargs
.
K
),
make_tuple
(
kargs
.
stride_B
,
1
),
number
<
GemmPipeline
::
AlignmentB
>
{},
number
<
1
>
{});
}
}();
auto
a_pad_view
=
pad_tensor_view
(
a_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
sequence
<
0
,
GemmPipeline
::
kPadA
?
1
:
0
>
{});
auto
ABlockWindow
=
make_tile_window
(
a_pad_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
{
i_m
,
0
});
auto
b_pad_view
=
pad_tensor_view
(
b_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kN
>
{},
number
<
TilePartitioner
::
kK
>
{}),
sequence
<
0
,
GemmPipeline
::
kPadB
?
1
:
0
>
{});
auto
BBlockWindow
=
make_tile_window
(
b_pad_view
,
make_tuple
(
number
<
TilePartitioner
::
kN
>
{},
number
<
TilePartitioner
::
kK
>
{}),
{
i_n
,
0
});
// allocate LDS
__shared__
char
smem_ptr
[
GetSmemSize
()];
const
index_t
num_loop
=
(
kargs
.
K
+
TilePartitioner
::
kK
-
1
)
/
TilePartitioner
::
kK
;
auto
acc
=
GemmPipeline
{}(
ABlockWindow
,
BBlockWindow
,
num_loop
,
smem_ptr
);
CODataType
*
c_start
=
static_cast
<
CODataType
*>
(
kargs
.
c_ptr
);
auto
c_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
LayoutC
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
c_start
,
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
1
,
kargs
.
stride_C
),
number
<
GemmPipeline
::
AlignmentC
>
{},
number
<
1
>
{});
}
else
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
c_start
,
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
kargs
.
stride_C
,
1
),
number
<
GemmPipeline
::
AlignmentC
>
{},
number
<
1
>
{});
}
}();
auto
c_pad_view
=
pad_tensor_view
(
c_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
sequence
<
0
,
GemmPipeline
::
kPadC
?
1
:
0
>
{});
auto
CBlockWindow_pad
=
make_tile_window
(
c_pad_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
{
i_m
,
i_n
});
EpiloguePipeline
{}(
CBlockWindow_pad
,
acc
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp
0 → 100644
View file @
09d4c3a4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
template
<
typename
BlockGemmShape_
>
struct
GemmTilePartitioner
{
using
BlockGemmShape
=
ck_tile
::
remove_cvref_t
<
BlockGemmShape_
>
;
static
constexpr
ck_tile
::
index_t
kM
=
BlockGemmShape
::
kM
;
static
constexpr
ck_tile
::
index_t
kN
=
BlockGemmShape
::
kN
;
static
constexpr
ck_tile
::
index_t
kK
=
BlockGemmShape
::
kK
;
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
M
,
ck_tile
::
index_t
N
,
ck_tile
::
index_t
batch_size
)
{
ck_tile
::
index_t
GridDimX
=
(
M
+
kM
-
1
)
/
kM
;
ck_tile
::
index_t
GridDimY
=
(
N
+
kN
-
1
)
/
kN
;
ck_tile
::
index_t
GridDimZ
=
batch_size
;
return
dim3
(
GridDimX
,
GridDimY
,
GridDimZ
);
}
CK_TILE_DEVICE
auto
operator
()()
{
const
index_t
iM
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
x
*
kM
);
const
index_t
iN
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
y
*
kN
);
return
ck_tile
::
make_tuple
(
iM
,
iN
);
}
};
}
// namespace ck_tile
Prev
1
2
3
4
5
6
7
8
…
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment