Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
3aec6f03
Unverified
Commit
3aec6f03
authored
Jan 13, 2025
by
arai713
Committed by
GitHub
Jan 13, 2025
Browse files
Merge branch 'develop' into codegen_hiprtc
parents
cdfceb0a
5d671a5f
Changes
42
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
693 additions
and
147 deletions
+693
-147
Jenkinsfile
Jenkinsfile
+41
-6
example/ck_tile/03_gemm/CMakeLists.txt
example/ck_tile/03_gemm/CMakeLists.txt
+1
-1
example/ck_tile/03_gemm/README.md
example/ck_tile/03_gemm/README.md
+5
-2
example/ck_tile/03_gemm/gemm_basic.cpp
example/ck_tile/03_gemm/gemm_basic.cpp
+0
-2
example/ck_tile/03_gemm/gemm_basic.hpp
example/ck_tile/03_gemm/gemm_basic.hpp
+21
-0
example/ck_tile/03_gemm/script/benchmark_basic.sh
example/ck_tile/03_gemm/script/benchmark_basic.sh
+13
-0
example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh
example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh
+13
-0
example/ck_tile/03_gemm/script/run_full_test.sh
example/ck_tile/03_gemm/script/run_full_test.sh
+22
-2
example/ck_tile/03_gemm/script/smoke_test_basic.sh
example/ck_tile/03_gemm/script/smoke_test_basic.sh
+1
-1
example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh
example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh
+35
-0
example/ck_tile/03_gemm/universal_gemm.cpp
example/ck_tile/03_gemm/universal_gemm.cpp
+19
-35
example/ck_tile/05_reduce/reduce.cpp
example/ck_tile/05_reduce/reduce.cpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
...device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
+13
-4
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
...vice/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
+5
-1
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
...tion/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
+72
-13
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+308
-4
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp
...k_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp
+11
-10
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp
.../pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp
+28
-6
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
...a/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
+38
-11
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp
...e/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp
+46
-48
No files found.
Jenkinsfile
View file @
3aec6f03
...
@@ -326,12 +326,38 @@ def cmake_build(Map conf=[:]){
...
@@ -326,12 +326,38 @@ def cmake_build(Map conf=[:]){
if
(
package_build
==
true
&&
(
env
.
BRANCH_NAME
==
"develop"
||
env
.
BRANCH_NAME
==
"amd-master"
))
{
if
(
package_build
==
true
&&
(
env
.
BRANCH_NAME
==
"develop"
||
env
.
BRANCH_NAME
==
"amd-master"
))
{
archiveArtifacts
artifacts:
"build/*.deb"
,
allowEmptyArchive:
true
,
fingerprint:
true
archiveArtifacts
artifacts:
"build/*.deb"
,
allowEmptyArchive:
true
,
fingerprint:
true
}
}
//check the node gpu architecture
def
arch_type
=
0
sh
'rocminfo | tee rocminfo.log'
if
(
runShell
(
'grep -n "gfx90a" rocminfo.log'
)
){
arch_type
=
1
}
else
if
(
runShell
(
'grep -n "gfx942" rocminfo.log'
)
)
{
arch_type
=
2
}
if
(
params
.
RUN_CK_TILE_FMHA_TESTS
){
if
(
params
.
RUN_CK_TILE_FMHA_TESTS
){
try
{
try
{
archiveArtifacts
"perf_fmha_fwd_*.log"
archiveArtifacts
"perf_fmha_*.log"
archiveArtifacts
"perf_fmha_bwd_*.log"
if
(
arch_type
==
1
){
stash
includes:
"perf_fmha_**_gfx942.log"
,
name:
"perf_fmha_log_gfx942"
stash
includes:
"perf_fmha_**_gfx90a.log"
,
name:
"perf_fmha_log_gfx90a"
stash
includes:
"perf_fmha_**_gfx90a.log"
,
name:
"perf_fmha_log_gfx90a"
}
else
if
(
arch_type
==
2
){
stash
includes:
"perf_fmha_**_gfx942.log"
,
name:
"perf_fmha_log_gfx942"
}
}
catch
(
Exception
err
){
echo
"could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing."
}
}
if
(
params
.
RUN_CK_TILE_GEMM_TESTS
){
try
{
archiveArtifacts
"perf_tile_gemm_*.log"
if
(
arch_type
==
1
){
stash
includes:
"perf_tile_gemm_**_fp16_gfx90a.log"
,
name:
"perf_tile_gemm_log_gfx90a"
}
else
if
(
arch_type
==
2
){
stash
includes:
"perf_tile_gemm_**_fp16_gfx942.log"
,
name:
"perf_tile_gemm_log_gfx942"
}
}
}
catch
(
Exception
err
){
catch
(
Exception
err
){
echo
"could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing."
echo
"could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing."
...
@@ -630,6 +656,15 @@ def process_results(Map conf=[:]){
...
@@ -630,6 +656,15 @@ def process_results(Map conf=[:]){
echo
"could not locate the FMHA performance logs: ${err.getMessage()}."
echo
"could not locate the FMHA performance logs: ${err.getMessage()}."
}
}
}
}
if
(
params
.
RUN_CK_TILE_GEMM_TESTS
){
try
{
unstash
"perf_tile_gemm_log_gfx942"
unstash
"perf_tile_gemm_log_gfx90a"
}
catch
(
Exception
err
){
echo
"could not locate the GEMM performance logs: ${err.getMessage()}."
}
}
if
(
params
.
RUN_FULL_QA
){
if
(
params
.
RUN_FULL_QA
){
// unstash perf files to master
// unstash perf files to master
unstash
"ckprofiler_0.2.0_amd64.deb"
unstash
"ckprofiler_0.2.0_amd64.deb"
...
@@ -956,7 +991,7 @@ pipeline {
...
@@ -956,7 +991,7 @@ pipeline {
environment
{
environment
{
setup_args
=
"NO_CK_BUILD"
setup_args
=
"NO_CK_BUILD"
execute_args
=
""" ../script/cmake-ck-dev.sh ../ gfx90a && \
execute_args
=
""" ../script/cmake-ck-dev.sh ../ gfx90a && \
make -j64 tile_example_gemm_basic && \
make -j64 tile_example_gemm_basic
tile_example_gemm_universal
&& \
cd ../ &&
cd ../ &&
example/ck_tile/03_gemm/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx90a """
example/ck_tile/03_gemm/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx90a """
}
}
...
@@ -975,7 +1010,7 @@ pipeline {
...
@@ -975,7 +1010,7 @@ pipeline {
environment
{
environment
{
setup_args
=
"NO_CK_BUILD"
setup_args
=
"NO_CK_BUILD"
execute_args
=
""" ../script/cmake-ck-dev.sh ../ gfx942 && \
execute_args
=
""" ../script/cmake-ck-dev.sh ../ gfx942 && \
make -j64 tile_example_gemm_basic && \
make -j64 tile_example_gemm_basic
tile_example_gemm_universal
&& \
cd ../ &&
cd ../ &&
example/ck_tile/03_gemm/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx942 """
example/ck_tile/03_gemm/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx942 """
}
}
...
...
example/ck_tile/03_gemm/CMakeLists.txt
View file @
3aec6f03
add_executable
(
tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp
)
add_executable
(
tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp
)
add_executable
(
tile_example_universal
_gemm
EXCLUDE_FROM_ALL universal_gemm.cpp
)
add_executable
(
tile_example_
gemm_
universal EXCLUDE_FROM_ALL universal_gemm.cpp
)
example/ck_tile/03_gemm/README.md
View file @
3aec6f03
...
@@ -11,9 +11,9 @@ sh ../script/cmake-ck-dev.sh ../ <arch>
...
@@ -11,9 +11,9 @@ sh ../script/cmake-ck-dev.sh ../ <arch>
# The basic pipeline method on the gemm calculation
# The basic pipeline method on the gemm calculation
make tile_example_gemm_basic -j
make tile_example_gemm_basic -j
# The memory bound pipeline on the gemm calculation
# The memory bound pipeline on the gemm calculation
make tile_example_gemm_
mem_pipeline
-j
make tile_example_gemm_
universal
-j
```
```
This will result in an executable
`build/bin/tile_example_gemm_basic`
This will result in an executable
`build/bin/tile_example_gemm_basic`
&
`build/bin/tile_example_gemm_universal`
## example
## example
```
```
...
@@ -22,6 +22,9 @@ args:
...
@@ -22,6 +22,9 @@ args:
-m m dimension (default:1024)
-m m dimension (default:1024)
-n n dimension (default:2048)
-n n dimension (default:2048)
-k k dimension (default:64)
-k k dimension (default:64)
-a_layout Tensor A data layout (default: R)
-b_layout Tensor B data layout (default: R)
-c_layout Tensor C data layout (default: R)
-stride_a Tensor A stride (default:0)
-stride_a Tensor A stride (default:0)
-stride_b Tensor B stride (default:0)
-stride_b Tensor B stride (default:0)
-stride_c Tensor C stride (default:0)
-stride_c Tensor C stride (default:0)
...
...
example/ck_tile/03_gemm/gemm_basic.cpp
View file @
3aec6f03
...
@@ -9,8 +9,6 @@
...
@@ -9,8 +9,6 @@
#include <string>
#include <string>
#include <tuple>
#include <tuple>
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host.hpp"
#include "gemm_basic.hpp"
#include "gemm_basic.hpp"
...
...
example/ck_tile/03_gemm/gemm_basic.hpp
View file @
3aec6f03
...
@@ -8,6 +8,27 @@
...
@@ -8,6 +8,27 @@
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#define CK_TILE_PIPELINE_COMPUTE 1
#define CK_TILE_PIPELINE_MEMORY 2
#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
#endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
#else
#error "unsupported CK_TILE_PIPELINE_DEFAULT value"
#endif
template
<
typename
DataType
>
template
<
typename
DataType
>
struct
GemmBasicTypeConfig
;
struct
GemmBasicTypeConfig
;
...
...
example/ck_tile/03_gemm/script/benchmark_basic.sh
0 → 100755
View file @
3aec6f03
#!/bin/sh
EXE
=
"
$(
find
.
-name
tile_example_gemm_basic
-type
f |
head
-n
1
)
"
VALID
=
0
for
b_matrix_layout
in
"R"
"C"
;
do
for
m
in
"64"
"512"
"1024"
"2048"
;
do
for
n
in
"512"
"1024"
"2048"
;
do
for
k
in
"64"
"512"
"1024"
"2048"
;
do
$EXE
-prec
=
fp16
-b
=
1
-m
=
$m
-n
=
$n
-k
=
$k
-a_layout
=
"R"
-b_layout
=
"
$b_matrix_layout
"
-c_layout
=
"R"
-v
=
$VALID
done
done
done
done
example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh
0 → 100755
View file @
3aec6f03
#!/bin/sh
EXE
=
"
$(
find
.
-name
tile_example_gemm_universal
-type
f |
head
-n
1
)
"
VALID
=
0
for
b_matrix_layout
in
"R"
"C"
;
do
for
m
in
"64"
"512"
"1024"
"2048"
;
do
for
n
in
"512"
"1024"
"2048"
;
do
for
k
in
"64"
"512"
"1024"
"2048"
;
do
$EXE
-prec
=
fp16
-b
=
1
-m
=
$m
-n
=
$n
-k
=
$k
-a_layout
=
"R"
-b_layout
=
"
$b_matrix_layout
"
-c_layout
=
"R"
-v
=
$VALID
done
done
done
done
example/ck_tile/03_gemm/script/run_full_test.sh
View file @
3aec6f03
...
@@ -19,7 +19,27 @@ echo 'Host name: ' $host_name
...
@@ -19,7 +19,27 @@ echo 'Host name: ' $host_name
export
GPU_arch
=
$4
export
GPU_arch
=
$4
echo
'GPU_arch: '
$GPU_arch
echo
'GPU_arch: '
$GPU_arch
function
print_log_header
(){
rm
-f
$1
;
echo
'On branch '
$3
&>
$1
;
echo
'Node name: '
$4
>>
$1
;
# get GPU architecture and compute units from rocminfo
echo
-n
"GPU_arch: "
>>
$1
;
rocminfo |
grep
"Name:"
|
grep
"gfx"
>>
$1
;
rocminfo |
grep
"Compute Unit:"
>>
$1
;
hipcc
--version
|
grep
-e
'HIP version'
>>
$1
;
echo
'Environment type: '
$2
>>
$1
;
/opt/rocm/bin/amdclang++
--version
|
grep
-e
'InstalledDir'
>>
$1
;
}
# run verification tests
# run verification tests
example/ck_tile/03_gemm/script/smoke_test.sh
example/ck_tile/03_gemm/script/smoke_test_basic.sh
example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh
# run performance benchmarks
export
gemm_basic_log
=
"perf_tile_gemm_basic_fp16_
$GPU_arch
.log"
print_log_header
$gemm_basic_log
$env_type
$branch
$host_name
example/ck_tile/03_gemm/script/benchmark_basic.sh 2>&1 |
tee
-a
$gemm_basic_log
# We do not have a performance benchmark for gemm yet. Will add it in the future.
export
gemm_mem_pipeline_log
=
"perf_tile_gemm_mem_pipeline_fp16_
$GPU_arch
.log"
\ No newline at end of file
print_log_header
$gemm_mem_pipeline_log
$env_type
$branch
$host_name
example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh 2>&1 |
tee
-a
$gemm_mem_pipeline_log
example/ck_tile/03_gemm/script/smoke_test.sh
→
example/ck_tile/03_gemm/script/smoke_test
_basic
.sh
View file @
3aec6f03
...
@@ -32,4 +32,4 @@ set -x
...
@@ -32,4 +32,4 @@ set -x
run_fp16_tests
run_fp16_tests
set
+x
set
+x
\ No newline at end of file
example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh
0 → 100755
View file @
3aec6f03
#!/bin/bash
EXE
=
"
$(
find
.
-name
tile_example_gemm_universal
-type
f |
head
-n
1
)
"
KNAME
=
1
export
CK_WARMUP
=
0
export
CK_REPEAT
=
1
COMMON_ARGS
=
'-v=2 -warmup=0 -repeat=1'
run_fp16_tests
()
{
for
batch
in
1 2
;
do
for
m
in
128 1024
;
do
for
n
in
128 2048
;
do
for
k
in
32 64
;
do
$EXE
-b
=
$batch
-m
=
$m
-n
=
$n
-k
=
$k
-stride_a
=
0
-stride_b
=
0
-stride_c
=
0
-e
=
1e-5
-prec
=
fp16
$COMMON_ARGS
if
[
$?
-eq
0
]
;
then
echo
"Success: Test with batch=
$batch
, m=
$m
, n=
$n
, k=
$k
executed successfully."
else
echo
"Error: Test with batch=
$batch
, m=
$m
, n=
$n
, k=
$k
failed to execute properly."
# Optionally, exit or break if you need to halt further execution
# exit 1
fi
done
done
done
done
}
set
-x
run_fp16_tests
set
+x
example/ck_tile/03_gemm/universal_gemm.cpp
View file @
3aec6f03
...
@@ -9,18 +9,9 @@
...
@@ -9,18 +9,9 @@
#include <string>
#include <string>
#include <tuple>
#include <tuple>
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host.hpp"
#include "gemm_basic.hpp"
#include "gemm_basic.hpp"
#define CK_TILE_PIPELINE_COMPUTE 1
#define CK_TILE_PIPELINE_MEMORY 2
#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
#endif
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
float
gemm_calc
(
const
ck_tile
::
GemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
float
gemm_calc
(
const
ck_tile
::
GemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
{
...
@@ -71,12 +62,11 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -71,12 +62,11 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
>>
;
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
>>
;
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
using
BaseGemmPipeline
=
ck_tile
::
BaseGemmPipelineAgBgCrMem
<
using
GemmPipelineProblem
=
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>
;
using
BaseGemmPipeline
=
ck_tile
::
BaseGemmPipelineAgBgCrCompV3
<
#endif
using
BaseGemmPipeline
=
UNIVERSAL_GEMM_PIPELINE
<
GemmPipelineProblem
>
;
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>>
;
const
ck_tile
::
index_t
k_grain
=
args
.
k_batch
*
K_Tile
;
const
ck_tile
::
index_t
k_grain
=
args
.
k_batch
*
K_Tile
;
const
ck_tile
::
index_t
K_split
=
(
args
.
K
+
k_grain
-
1
)
/
k_grain
*
K_Tile
;
const
ck_tile
::
index_t
K_split
=
(
args
.
K
+
k_grain
-
1
)
/
k_grain
*
K_Tile
;
...
@@ -89,26 +79,20 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -89,26 +79,20 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
const
auto
Run
=
[
&
](
const
auto
has_hot_loop_
,
const
auto
tail_number_
)
{
const
auto
Run
=
[
&
](
const
auto
has_hot_loop_
,
const
auto
tail_number_
)
{
constexpr
bool
has_hot_loop_v
=
has_hot_loop_
.
value
;
constexpr
bool
has_hot_loop_v
=
has_hot_loop_
.
value
;
constexpr
auto
tail_number_v
=
tail_number_
.
value
;
constexpr
auto
tail_number_v
=
tail_number_
.
value
;
constexpr
auto
scheduler
=
GEMM_PIPELINE_SCHEDULER
;
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
using
GemmPipeline
=
ck_tile
::
GemmPipelineAgBgCrMem
<
using
UniversalGemmProblem
=
ck_tile
::
UniversalGemmPipelineProblem
<
ADataType
,
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
BDataType
,
using
GemmPipeline
=
ck_tile
::
GemmPipelineAgBgCrCompV3
<
AccDataType
,
#endif
GemmShape
,
ck_tile
::
UniversalGemmPipelineProblem
<
ADataType
,
Traits
,
BDataType
,
scheduler
,
AccDataType
,
has_hot_loop_v
,
GemmShape
,
tail_number_v
>
;
Traits
,
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
using
GemmPipeline
=
GEMM_PIPELINE
<
UniversalGemmProblem
>
;
ck_tile
::
GemmPipelineScheduler
::
Interwave
,
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
>
;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
auto
kargs
=
Kernel
::
MakeKernelArgs
(
args
);
ck_tile
::
GemmPipelineScheduler
::
Intrawave
,
#endif
has_hot_loop_v
,
tail_number_v
>>
;
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
>
;
auto
kargs
=
Kernel
::
MakeKernelArgs
(
args
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
k_batch
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
k_batch
);
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
...
...
example/ck_tile/05_reduce/reduce.cpp
View file @
3aec6f03
...
@@ -52,7 +52,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -52,7 +52,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
// using WarpTile = ck_tile::sequence<1, 512>;
// using WarpTile = ck_tile::sequence<1, 512>;
// using Vector = ck_tile::sequence<1, 8>;
// using Vector = ck_tile::sequence<1, 8>;
constexpr
ck_tile
::
index_t
kBlockSize
=
512
;
constexpr
ck_tile
::
index_t
kBlockSize
=
256
;
constexpr
ck_tile
::
index_t
kBlockPerCu
=
1
;
constexpr
ck_tile
::
index_t
kBlockPerCu
=
1
;
ck_tile
::
index_t
kGridSize
=
(
m
/
BlockTile
::
at
(
ck_tile
::
number
<
0
>
{}));
ck_tile
::
index_t
kGridSize
=
(
m
/
BlockTile
::
at
(
ck_tile
::
number
<
0
>
{}));
std
::
cout
<<
"grid size "
<<
kGridSize
<<
std
::
endl
;
std
::
cout
<<
"grid size "
<<
kGridSize
<<
std
::
endl
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
View file @
3aec6f03
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -1558,14 +1558,23 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
...
@@ -1558,14 +1558,23 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
}
}
}
}
if
(
!
(
arg
.
Conv_C_
%
BBlockTransferSrcScalarPerVector
==
0
&&
const
bool
is_w_pad_zero
=
arg
.
input_left_pads_
[
NDimSpatial
-
1
]
==
0
&&
arg
.
input_right_pads_
[
NDimSpatial
-
1
]
==
0
;
const
auto
X
=
arg
.
filter_spatial_lengths_
[
NDimSpatial
-
1
];
const
bool
XC_access_allowed
=
arg
.
Conv_G_
==
1
&&
(
arg
.
Conv_C_
*
X
)
%
BBlockTransferSrcScalarPerVector
==
0
&&
is_w_pad_zero
;
if
(
!
((
arg
.
Conv_C_
%
BBlockTransferSrcScalarPerVector
==
0
||
XC_access_allowed
)
&&
arg
.
Conv_K_
%
ABlockTransferSrcScalarPerVector
==
0
))
arg
.
Conv_K_
%
ABlockTransferSrcScalarPerVector
==
0
))
{
{
if
(
!
(
arg
.
Conv_K_
==
1
&&
arg
.
compute_ptr_offset_of_batch_
.
BatchStrideA_
==
1
))
if
(
!
(
arg
.
Conv_K_
==
1
&&
arg
.
compute_ptr_offset_of_batch_
.
BatchStrideA_
==
1
&&
NumGroupsToMerge
>
1
))
{
{
return
false
;
return
false
;
}
}
if
(
!
(
arg
.
Conv_C_
==
1
&&
arg
.
compute_ptr_offset_of_batch_
.
BatchStrideB_
==
1
))
if
(
!
(
arg
.
Conv_C_
==
1
&&
arg
.
compute_ptr_offset_of_batch_
.
BatchStrideB_
==
1
&&
NumGroupsToMerge
>
1
))
{
{
return
false
;
return
false
;
}
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
View file @
3aec6f03
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -584,6 +584,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
...
@@ -584,6 +584,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
{
{
return
false
;
return
false
;
}
}
if
(
!
is_bf16_atomic_supported
()
&&
std
::
is_same_v
<
CDataType
,
ck
::
bhalf_t
>
)
{
return
false
;
}
if
constexpr
(
NDimSpatial
==
1
)
if
constexpr
(
NDimSpatial
==
1
)
{
{
if
constexpr
(
!
is_GNWC_GKXC_GNWK
<
InLayout
,
WeiLayout
,
OutLayout
>
())
if
constexpr
(
!
is_GNWC_GKXC_GNWK
<
InLayout
,
WeiLayout
,
OutLayout
>
())
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
View file @
3aec6f03
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -53,7 +53,20 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -53,7 +53,20 @@ struct ThreadwiseTensorSliceTransfer_v3r1
using
SrcCoordStep
=
decltype
(
make_tensor_coordinate_step
(
SrcDesc
{},
Index
{}));
using
SrcCoordStep
=
decltype
(
make_tensor_coordinate_step
(
SrcDesc
{},
Index
{}));
using
DstCoordStep
=
decltype
(
make_tensor_coordinate_step
(
DstDesc
{},
Index
{}));
using
DstCoordStep
=
decltype
(
make_tensor_coordinate_step
(
DstDesc
{},
Index
{}));
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
I6
=
Number
<
6
>
{};
static
constexpr
auto
I7
=
Number
<
7
>
{};
static
constexpr
auto
I8
=
Number
<
8
>
{};
static
constexpr
auto
I10
=
Number
<
10
>
{};
static
constexpr
auto
I12
=
Number
<
12
>
{};
static
constexpr
auto
I13
=
Number
<
13
>
{};
static
constexpr
auto
I14
=
Number
<
14
>
{};
static
constexpr
auto
I16
=
Number
<
16
>
{};
static
constexpr
index_t
PackedSize
=
[]()
{
static
constexpr
index_t
PackedSize
=
[]()
{
if
constexpr
(
is_same_v
<
remove_cvref_t
<
SrcData
>
,
pk_i4_t
>
)
if
constexpr
(
is_same_v
<
remove_cvref_t
<
SrcData
>
,
pk_i4_t
>
)
...
@@ -198,9 +211,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -198,9 +211,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
src_oob_thread_scratch_tuple_
(
thread_scratch_id
)
src_oob_thread_scratch_tuple_
(
thread_scratch_id
)
.
template
SetAsType
<
bool
>(
src_data_idx_seq
,
is_src_valid
);
.
template
SetAsType
<
bool
>(
src_data_idx_seq
,
is_src_valid
);
using
src_vector_type
=
vector_type_maker_t
<
SrcData
,
SrcScalarPerVector
>
;
using
src_vector_t
=
typename
src_vector_type
::
type
;
using
dst_vector_type
=
vector_type_maker_t
<
DstData
,
SrcScalarPerVector
>
;
using
dst_vector_type
=
vector_type_maker_t
<
DstData
,
SrcScalarPerVector
>
;
using
dst_vector_t
=
typename
dst_vector_type
::
type
;
using
dst_vector_t
=
typename
dst_vector_type
::
type
;
dst_vector_type
op_r_v
;
dst_vector_type
op_r_v
;
...
@@ -234,14 +244,63 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -234,14 +244,63 @@ struct ThreadwiseTensorSliceTransfer_v3r1
using
src_elem_op_vec_t
=
typename
vector_type
<
SrcData
,
elem_op_vec_len
>::
type
;
using
src_elem_op_vec_t
=
typename
vector_type
<
SrcData
,
elem_op_vec_len
>::
type
;
using
dst_elem_op_vec_t
=
typename
vector_type
<
DstData
,
elem_op_vec_len
>::
type
;
using
dst_elem_op_vec_t
=
typename
vector_type
<
DstData
,
elem_op_vec_len
>::
type
;
auto
src_vector_container
=
src_vector_type
{
using
VectorSizeLookupTable
=
Tuple
<
Sequence
<>
,
src_buf
.
template
Get
<
src_vector_t
>(
src_coord_
.
GetOffset
()
/
PackedSize
,
true
)};
Sequence
<
I1
>
,
Sequence
<
I2
>
,
static_for
<
0
,
SrcScalarPerVector
/
elem_op_vec_len
,
1
>
{}([
&
](
auto
idx
)
{
Sequence
<
I2
,
I1
>
,
// apply the src elementwise op and convert to DstData under the hood if needed
Sequence
<
I4
>
,
src_element_op_
(
op_r_v
.
template
AsType
<
dst_elem_op_vec_t
>()(
idx
),
Sequence
<
I4
,
I1
>
,
src_vector_container
.
template
AsType
<
src_elem_op_vec_t
>()[
idx
]);
Sequence
<
I4
,
I2
>
,
});
Sequence
<
I4
,
I2
,
I1
>
,
Sequence
<
I8
>
,
Sequence
<
I8
,
I1
>
,
Sequence
<
I8
,
I2
>
,
Sequence
<
I8
,
I2
,
I1
>
,
Sequence
<
I8
,
I4
>
,
Sequence
<
I8
,
I4
,
I1
>
,
Sequence
<
I8
,
I4
,
I2
>
,
Sequence
<
I8
,
I4
,
I2
,
I1
>
,
Sequence
<
I16
>>
;
using
VectorOffsetsLookupTable
=
Tuple
<
Sequence
<>
,
Sequence
<
I0
>
,
Sequence
<
I0
>
,
Sequence
<
I0
,
I2
>
,
Sequence
<
I0
>
,
Sequence
<
I0
,
I4
>
,
Sequence
<
I0
,
I4
>
,
Sequence
<
I0
,
I4
,
I6
>
,
Sequence
<
I0
>
,
Sequence
<
I0
,
I8
>
,
Sequence
<
I0
,
I8
>
,
Sequence
<
I0
,
I8
,
I10
>
,
Sequence
<
I0
,
I8
>
,
Sequence
<
I0
,
I8
,
I12
>
,
Sequence
<
I0
,
I8
,
I12
>
,
Sequence
<
I0
,
I8
,
I12
,
I14
>
,
Sequence
<
I0
>>
;
static_for
<
0
,
tuple_element_t
<
SrcScalarPerVector
,
VectorSizeLookupTable
>::
Size
(),
1
>
{}(
[
&
](
auto
v_idx
)
{
constexpr
auto
VectorLoadSize
=
tuple_element_t
<
SrcScalarPerVector
,
VectorSizeLookupTable
>::
At
(
v_idx
);
constexpr
auto
LoadOffset
=
tuple_element_t
<
SrcScalarPerVector
,
VectorOffsetsLookupTable
>::
At
(
v_idx
);
using
src_vector_container
=
vector_type_maker_t
<
SrcData
,
VectorLoadSize
>
;
using
src_vector_container_t
=
typename
src_vector_container
::
type
;
src_vector_container
src_vector
=
src_vector_container
{
src_buf
.
template
Get
<
src_vector_container_t
>(
src_coord_
.
GetOffset
()
/
PackedSize
+
LoadOffset
,
true
)};
static_for
<
0
,
VectorLoadSize
/
elem_op_vec_len
,
1
>
{}([
&
](
auto
idx
)
{
// apply the src elementwise op and convert to DstData under the hood if
// needed
src_element_op_
(
op_r_v
.
template
AsType
<
dst_elem_op_vec_t
>()(
idx
+
LoadOffset
),
src_vector
.
template
AsType
<
src_elem_op_vec_t
>()[
idx
]);
});
});
// copy data from src_vector_container into src_thread_scratch_
// copy data from src_vector_container into src_thread_scratch_
src_thread_scratch_tuple_
(
thread_scratch_id
)
src_thread_scratch_tuple_
(
thread_scratch_id
)
...
...
include/ck/utility/data_type.hpp
View file @
3aec6f03
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -327,7 +327,77 @@ struct vector_type<T, 2, typename ck::enable_if_t<is_native_type<T>()>>
...
@@ -327,7 +327,77 @@ struct vector_type<T, 2, typename ck::enable_if_t<is_native_type<T>()>>
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
4
,
typename
ck
::
enable_if_t
<
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
3
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d3_t
__attribute__
((
ext_vector_type
(
3
)));
using
type
=
d3_t
;
union
{
d3_t
d3_
;
StaticallyIndexedArray
<
d1_t
,
3
>
d1x3_
;
StaticallyIndexedArray
<
d2_t
,
1
>
d2x1_
;
StaticallyIndexedArray
<
d3_t
,
1
>
d3x1_
;
}
data_
;
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{
0
}}
{}
__host__
__device__
constexpr
vector_type
(
type
v
)
:
data_
{
v
}
{}
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d3_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x3_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x1_
;
}
else
if
constexpr
(
is_same
<
X
,
d3_t
>::
value
)
{
return
data_
.
d3x1_
;
}
else
{
return
err
;
}
}
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d3_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x3_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x1_
;
}
else
if
constexpr
(
is_same
<
X
,
d3_t
>::
value
)
{
return
data_
.
d3x1_
;
}
else
{
return
err
;
}
}
};
template
<
typename
T
>
struct
vector_type
<
T
,
4
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
{
using
d1_t
=
T
;
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
@@ -397,7 +467,159 @@ struct vector_type<T, 4, typename ck::enable_if_t<is_native_type<T>()>>
...
@@ -397,7 +467,159 @@ struct vector_type<T, 4, typename ck::enable_if_t<is_native_type<T>()>>
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
8
,
typename
ck
::
enable_if_t
<
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
5
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
T
d5_t
__attribute__
((
ext_vector_type
(
5
)));
using
type
=
d5_t
;
union
{
d5_t
d5_
;
StaticallyIndexedArray
<
d1_t
,
5
>
d1x5_
;
StaticallyIndexedArray
<
d4_t
,
1
>
d4x1_
;
StaticallyIndexedArray
<
d5_t
,
1
>
d5x1_
;
}
data_
;
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{
0
}}
{}
__host__
__device__
constexpr
vector_type
(
type
v
)
:
data_
{
v
}
{}
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d5_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x5_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x1_
;
}
else
if
constexpr
(
is_same
<
X
,
d5_t
>::
value
)
{
return
data_
.
d5x1_
;
}
else
{
return
err
;
}
}
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d5_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x5_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x1_
;
}
else
if
constexpr
(
is_same
<
X
,
d5_t
>::
value
)
{
return
data_
.
d5x1_
;
}
else
{
return
err
;
}
}
};
template
<
typename
T
>
struct
vector_type
<
T
,
7
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
T
d7_t
__attribute__
((
ext_vector_type
(
7
)));
using
type
=
d7_t
;
union
{
d7_t
d7_
;
StaticallyIndexedArray
<
d1_t
,
7
>
d1x7_
;
StaticallyIndexedArray
<
d2_t
,
3
>
d2x3_
;
StaticallyIndexedArray
<
d4_t
,
1
>
d4x1_
;
StaticallyIndexedArray
<
d7_t
,
1
>
d7x1_
;
}
data_
;
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{
0
}}
{}
__host__
__device__
constexpr
vector_type
(
type
v
)
:
data_
{
v
}
{}
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d7_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x7_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x3_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x1_
;
}
else
if
constexpr
(
is_same
<
X
,
d7_t
>::
value
)
{
return
data_
.
d7x1_
;
}
else
{
return
err
;
}
}
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d7_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x7_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x3_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x1_
;
}
else
if
constexpr
(
is_same
<
X
,
d7_t
>::
value
)
{
return
data_
.
d7x1_
;
}
else
{
return
err
;
}
}
};
template
<
typename
T
>
struct
vector_type
<
T
,
8
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
{
using
d1_t
=
T
;
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
@@ -479,7 +701,89 @@ struct vector_type<T, 8, typename ck::enable_if_t<is_native_type<T>()>>
...
@@ -479,7 +701,89 @@ struct vector_type<T, 8, typename ck::enable_if_t<is_native_type<T>()>>
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
16
,
typename
ck
::
enable_if_t
<
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
13
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
T
d8_t
__attribute__
((
ext_vector_type
(
8
)));
typedef
T
d13_t
__attribute__
((
ext_vector_type
(
13
)));
using
type
=
d13_t
;
union
{
d13_t
d13_
;
StaticallyIndexedArray
<
d1_t
,
13
>
d1x13_
;
StaticallyIndexedArray
<
d4_t
,
3
>
d4x3_
;
StaticallyIndexedArray
<
d8_t
,
1
>
d8x1_
;
StaticallyIndexedArray
<
d13_t
,
1
>
d13x1_
;
}
data_
;
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{
0
}}
{}
__host__
__device__
constexpr
vector_type
(
type
v
)
:
data_
{
v
}
{}
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d13_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x13_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x3_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x1_
;
}
else
if
constexpr
(
is_same
<
X
,
d13_t
>::
value
)
{
return
data_
.
d13x1_
;
}
else
{
return
err
;
}
}
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d13_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x13_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x3_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x1_
;
}
else
if
constexpr
(
is_same
<
X
,
d13_t
>::
value
)
{
return
data_
.
d13x1_
;
}
else
{
return
err
;
}
}
};
template
<
typename
T
>
struct
vector_type
<
T
,
16
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
{
using
d1_t
=
T
;
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp
View file @
3aec6f03
...
@@ -106,11 +106,6 @@ struct BlockFmhaPipelineQSKSVS
...
@@ -106,11 +106,6 @@ struct BlockFmhaPipelineQSKSVS
return
Policy
::
template
GetSmemSize
<
Problem
>();
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
}
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeQ
()
{
return
Policy
::
template
GetSmemSizeQ
<
Problem
>();
}
template
<
typename
QDramBlockWindowTmp
,
template
<
typename
QDramBlockWindowTmp
,
typename
KDramBlockWindowTmp
,
typename
KDramBlockWindowTmp
,
typename
VDramBlockWindowTmp
,
typename
VDramBlockWindowTmp
,
...
@@ -328,8 +323,7 @@ struct BlockFmhaPipelineQSKSVS
...
@@ -328,8 +323,7 @@ struct BlockFmhaPipelineQSKSVS
});
});
}
}
const
auto
v_prefetch
=
load_tile
(
v_dram_window
);
// prefetch load v tile
{
// tail
{
// tail
block_sync_lds
();
block_sync_lds
();
gemm_0
(
s_acc
,
q_lds_window
,
k_lds_window
);
gemm_0
(
s_acc
,
q_lds_window
,
k_lds_window
);
block_sync_lds
();
block_sync_lds
();
...
@@ -341,6 +335,10 @@ struct BlockFmhaPipelineQSKSVS
...
@@ -341,6 +335,10 @@ struct BlockFmhaPipelineQSKSVS
gemm_0
(
s_acc
,
q_lds_window
,
k_lds_window
);
gemm_0
(
s_acc
,
q_lds_window
,
k_lds_window
);
}
}
__builtin_amdgcn_sched_barrier
(
0
);
const
auto
v_prefetch
=
load_tile
(
v_dram_window
);
// prefetch load v tile
__builtin_amdgcn_sched_barrier
(
0
);
// STAGE 2, scale_s, add bias, mask, softmax
// STAGE 2, scale_s, add bias, mask, softmax
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
{
...
@@ -462,6 +460,12 @@ struct BlockFmhaPipelineQSKSVS
...
@@ -462,6 +460,12 @@ struct BlockFmhaPipelineQSKSVS
p_compute
,
sequence
<
1
>
{},
f_sum
,
SMPLComputeDataType
{
0
});
// rowsum(Pcompute{j})
p_compute
,
sequence
<
1
>
{},
f_sum
,
SMPLComputeDataType
{
0
});
// rowsum(Pcompute{j})
block_tile_reduce_sync
(
rowsum_p
,
f_sum
,
bool_constant
<
false
>
{});
block_tile_reduce_sync
(
rowsum_p
,
f_sum
,
bool_constant
<
false
>
{});
const
auto
p
=
cast_tile
<
PDataType
>
(
tile_elementwise_in
(
p_compute_element_func
,
p_compute
));
__builtin_amdgcn_sched_barrier
(
0
);
// l{j}, Oacc{j}
// l{j}, Oacc{j}
constexpr
auto
o_spans
=
decltype
(
o_acc
)
::
get_distributed_spans
();
constexpr
auto
o_spans
=
decltype
(
o_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
o_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
o_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
...
@@ -509,9 +513,6 @@ struct BlockFmhaPipelineQSKSVS
...
@@ -509,9 +513,6 @@ struct BlockFmhaPipelineQSKSVS
}
}
move_tile_window
(
v_dram_window
,
{
0
,
kK1
});
move_tile_window
(
v_dram_window
,
{
0
,
kK1
});
const
auto
p
=
cast_tile
<
PDataType
>
(
tile_elementwise_in
(
p_compute_element_func
,
p_compute
));
// STAGE 3, KV gemm
// STAGE 3, KV gemm
if
constexpr
(
k1_loops
>
1
)
if
constexpr
(
k1_loops
>
1
)
{
{
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp
View file @
3aec6f03
...
@@ -9,11 +9,33 @@
...
@@ -9,11 +9,33 @@
namespace
ck_tile
{
namespace
ck_tile
{
// This pipeline is qkv all located in LDS
// This pipeline is qkv all located in LDS
using
BlockFmhaPipelineQSKSVSDefaultPolicy
=
struct
BlockFmhaPipelineQSKSVSDefaultPolicy
BlockFmhaPipelineQXKSVSCustomPolicy
<
/* QLoadOnce = */
false
,
:
BlockFmhaPipelineQXKSVSCustomPolicy
<
/* QLoadOnce = */
false
,
/* AsyncCopyK = */
false
,
/* AsyncCopyK = */
false
,
/* AsyncCopyV = */
false
,
/* AsyncCopyV = */
false
,
/* NumPrefetchK = */
1
,
/* NumPrefetchK = */
1
,
/* NumPrefetchV = */
1
>
;
/* NumPrefetchV = */
1
>
{
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeK
()
{
return
MakeKLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
()
*
sizeof
(
typename
Problem
::
KDataType
);
}
// namespace ck_tile
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeV
()
{
return
MakeVLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
()
*
sizeof
(
typename
Problem
::
VDataType
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
return
max
(
GetSmemSizeQ
<
Problem
>
()
+
GetSmemSizeK
<
Problem
>
(),
GetSmemSizeV
<
Problem
>
())
+
GetSmemSizeDropout
<
Problem
>
();
}
};
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
View file @
3aec6f03
...
@@ -146,8 +146,16 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
...
@@ -146,8 +146,16 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentQ
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentQ
()
{
{
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
return
16
/
sizeof
(
QDataType
);
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0
;
constexpr
index_t
MaxVectorSize
=
16
/
sizeof
(
typename
Problem
::
QDataType
);
// this should align with MakeQDramTileDistribution()
constexpr
index_t
ElemPerThread
=
(
kMPerBlock
*
kKPerBlock
)
/
kBlockSize
;
static_assert
(
0
<
ElemPerThread
);
return
min
(
ElemPerThread
,
MaxVectorSize
);
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
...
@@ -156,19 +164,25 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
...
@@ -156,19 +164,25 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0
;
constexpr
index_t
K1
=
16
/
sizeof
(
QDataType
);
// use dwordx4. TODO: change this
constexpr
index_t
MaxVectorSize
=
16
/
sizeof
(
QDataType
);
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
M2
=
get_warp_size
()
/
K0
;
constexpr
index_t
ElemPerThread
=
(
kMPerBlock
*
kKPerBlock
)
/
kBlockSize
;
constexpr
index_t
M1
=
kBlockSize
/
get_warp_size
();
static_assert
(
0
<
ElemPerThread
);
constexpr
index_t
M0
=
kMPerBlock
/
(
M2
*
M1
);
constexpr
index_t
kMaxVecLoad
=
min
(
ElemPerThread
,
MaxVectorSize
);
constexpr
index_t
KPerThread
=
kMaxVecLoad
;
constexpr
index_t
KThreads
=
kKPerBlock
/
KPerThread
;
constexpr
index_t
MThreadPerWarp
=
get_warp_size
()
/
KThreads
;
constexpr
index_t
NumWarps
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
MPerThread
=
kMPerBlock
/
(
MThreadPerWarp
*
NumWarps
);
return
make_static_tile_distribution
(
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
MPerThread
,
NumWarps
,
MThreadPerWarp
>
,
sequence
<
KThreads
,
KPerThread
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
2
>
,
...
@@ -215,18 +229,31 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
...
@@ -215,18 +229,31 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>>
;
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>>
;
constexpr
index_t
WarpGemmM
=
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
::
at
(
number
<
0
>
{});
static_assert
(
WarpGemmM
==
4
||
WarpGemmM
==
16
||
WarpGemmM
==
32
);
constexpr
auto
warp_gemm
=
[]()
{
constexpr
auto
warp_gemm
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
half_t
>
&&
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
KDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
KDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
SaccDataType
,
float
>
)
std
::
is_same_v
<
typename
Problem
::
SaccDataType
,
float
>
)
{
{
return
WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution
{};
if
constexpr
(
WarpGemmM
==
32
)
return
WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution
{};
else
if
constexpr
(
WarpGemmM
==
16
)
return
WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution
{};
else
// WarpGemmM == 4
return
WarpGemmMfmaF16F16F32M4N64K16
{};
}
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
bf16_t
>
&&
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
KDataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
KDataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
SaccDataType
,
float
>
)
std
::
is_same_v
<
typename
Problem
::
SaccDataType
,
float
>
)
{
{
return
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution
{};
if
constexpr
(
WarpGemmM
==
32
)
return
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution
{};
else
if
constexpr
(
WarpGemmM
==
16
)
return
WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution
{};
else
// WarpGemmM == 4
return
WarpGemmMfmaBf16Bf16F32M4N64K16
{};
}
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
fp8_t
>
&&
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
fp8_t
>
&&
std
::
is_same_v
<
typename
Problem
::
KDataType
,
fp8_t
>
&&
std
::
is_same_v
<
typename
Problem
::
KDataType
,
fp8_t
>
&&
...
...
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp
View file @
3aec6f03
...
@@ -21,35 +21,20 @@ struct BlockGemmARegBRegCRegV1
...
@@ -21,35 +21,20 @@ struct BlockGemmARegBRegCRegV1
using
CDataType
=
remove_cvref_t
<
typename
Problem
::
CDataType
>
;
using
CDataType
=
remove_cvref_t
<
typename
Problem
::
CDataType
>
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
Problem
::
BlockGemmShape
>
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
Problem
::
BlockGemmShape
>
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
// C += A * B
static
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
template
<
typename
CBlockTensor
,
typename
ABlockTensor
,
typename
BBlockTensor
>
static
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
CK_TILE_DEVICE
void
operator
()(
CBlockTensor
&
c_block_tensor
,
static
constexpr
auto
config
=
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
const
ABlockTensor
&
a_block_tensor
,
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
const
BBlockTensor
&
b_block_tensor
)
const
static
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
static
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
static
constexpr
index_t
MIterPerWarp
=
MPerBlock
/
(
MWarp
*
WG
::
kM
);
static
constexpr
index_t
NIterPerWarp
=
NPerBlock
/
(
NWarp
*
WG
::
kN
);
static
constexpr
index_t
KIterPerWarp
=
KPerBlock
/
WG
::
kK
;
CK_TILE_DEVICE
static
constexpr
auto
MakeABlockDistributionEncode
()
{
{
static_assert
(
std
::
is_same_v
<
ADataType
,
remove_cv_t
<
typename
ABlockTensor
::
DataType
>>
&&
std
::
is_same_v
<
BDataType
,
remove_cv_t
<
typename
BBlockTensor
::
DataType
>>
&&
std
::
is_same_v
<
CDataType
,
remove_cv_t
<
typename
CBlockTensor
::
DataType
>>
,
"wrong!"
);
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
constexpr
auto
config
=
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
constexpr
index_t
MIterPerWarp
=
MPerBlock
/
(
MWarp
*
WG
::
kM
);
constexpr
index_t
NIterPerWarp
=
NPerBlock
/
(
NWarp
*
WG
::
kN
);
constexpr
index_t
KIterPerWarp
=
KPerBlock
/
WG
::
kK
;
// M->N Warp
constexpr
auto
a_block_outer_dstr_encoding
=
constexpr
auto
a_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
NWarp
>
,
tile_distribution_encoding
<
sequence
<
NWarp
>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
KIterPerWarp
>>
,
...
@@ -57,7 +42,14 @@ struct BlockGemmARegBRegCRegV1
...
@@ -57,7 +42,14 @@ struct BlockGemmARegBRegCRegV1
tuple
<
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
sequence
<
0
,
0
>>
{};
constexpr
auto
a_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
a_block_outer_dstr_encoding
,
typename
WG
::
AWarpDstrEncoding
{});
return
a_block_dstr_encode
;
}
CK_TILE_DEVICE
static
constexpr
auto
MakeBBlockDistributionEncode
()
{
constexpr
auto
b_block_outer_dstr_encoding
=
constexpr
auto
b_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
MWarp
>
,
tile_distribution_encoding
<
sequence
<
MWarp
>
,
tuple
<
sequence
<
NIterPerWarp
,
NWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
NIterPerWarp
,
NWarp
>
,
sequence
<
KIterPerWarp
>>
,
...
@@ -65,7 +57,14 @@ struct BlockGemmARegBRegCRegV1
...
@@ -65,7 +57,14 @@ struct BlockGemmARegBRegCRegV1
tuple
<
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
sequence
<
0
,
0
>>
{};
constexpr
auto
b_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
b_block_outer_dstr_encoding
,
typename
WG
::
BWarpDstrEncoding
{});
return
b_block_dstr_encode
;
}
CK_TILE_DEVICE
static
constexpr
auto
MakeCBlockDistributionEncode
()
{
constexpr
auto
c_block_outer_dstr_encoding
=
tile_distribution_encoding
<
constexpr
auto
c_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
sequence
<>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
NIterPerWarp
,
NWarp
>>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
NIterPerWarp
,
NWarp
>>
,
...
@@ -73,15 +72,28 @@ struct BlockGemmARegBRegCRegV1
...
@@ -73,15 +72,28 @@ struct BlockGemmARegBRegCRegV1
tuple
<
sequence
<
1
,
1
>>
,
tuple
<
sequence
<
1
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
sequence
<
0
,
0
>>
{};
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
c_block_outer_dstr_encoding
,
typename
WG
::
CWarpDstrEncoding
{});
constexpr
auto
a
_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
return
c
_block_dstr_encode
;
a_block_outer_dstr_encoding
,
typename
WG
::
AWarpDstrEncoding
{});
}
constexpr
auto
b_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
// C += A * B
b_block_outer_dstr_encoding
,
typename
WG
::
BWarpDstrEncoding
{});
template
<
typename
CBlockTensor
,
typename
ABlockTensor
,
typename
BBlockTensor
>
CK_TILE_DEVICE
void
operator
()(
CBlockTensor
&
c_block_tensor
,
const
ABlockTensor
&
a_block_tensor
,
const
BBlockTensor
&
b_block_tensor
)
const
{
static_assert
(
std
::
is_same_v
<
ADataType
,
remove_cv_t
<
typename
ABlockTensor
::
DataType
>>
&&
std
::
is_same_v
<
BDataType
,
remove_cv_t
<
typename
BBlockTensor
::
DataType
>>
&&
std
::
is_same_v
<
CDataType
,
remove_cv_t
<
typename
CBlockTensor
::
DataType
>>
,
"wrong!"
);
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
constexpr
auto
a_block_dstr_encode
=
MakeABlockDistributionEncode
();
c_block_outer_dstr_encoding
,
typename
WG
::
CWarpDstrEncoding
{});
constexpr
auto
b_block_dstr_encode
=
MakeBBlockDistributionEncode
();
constexpr
auto
c_block_dstr_encode
=
MakeCBlockDistributionEncode
();
// check ABC-block-distribution
// check ABC-block-distribution
static_assert
(
static_assert
(
...
@@ -159,20 +171,6 @@ struct BlockGemmARegBRegCRegV1
...
@@ -159,20 +171,6 @@ struct BlockGemmARegBRegCRegV1
CK_TILE_DEVICE
static
constexpr
auto
MakeCBlockTile
()
CK_TILE_DEVICE
static
constexpr
auto
MakeCBlockTile
()
{
{
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
constexpr
auto
config
=
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
constexpr
index_t
MIterPerWarp
=
MPerBlock
/
(
MWarp
*
WG
::
kM
);
constexpr
index_t
NIterPerWarp
=
NPerBlock
/
(
NWarp
*
WG
::
kN
);
// constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr
auto
c_block_outer_dstr_encoding
=
tile_distribution_encoding
<
constexpr
auto
c_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
sequence
<>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
NIterPerWarp
,
NWarp
>>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
NIterPerWarp
,
NWarp
>>
,
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment