Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
b8d11559
Unverified
Commit
b8d11559
authored
Feb 17, 2025
by
amd-khushbu
Committed by
GitHub
Feb 17, 2025
Browse files
Merge branch 'develop' into ck_profiler_m_instances
parents
7f3fe4e7
3b230208
Changes
174
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
790 additions
and
93 deletions
+790
-93
CMakeLists.txt
CMakeLists.txt
+1
-0
Jenkinsfile
Jenkinsfile
+3
-3
codegen/include/ck/host/device_batched_gemm_softmax_gemm/operation.hpp
...de/ck/host/device_batched_gemm_softmax_gemm/operation.hpp
+61
-0
codegen/include/ck/host/device_batched_gemm_softmax_gemm/problem.hpp
...lude/ck/host/device_batched_gemm_softmax_gemm/problem.hpp
+47
-0
codegen/include/ck/host/device_gemm_multiple_d/operation.hpp
codegen/include/ck/host/device_gemm_multiple_d/operation.hpp
+2
-0
codegen/include/ck/host/operation/gemm.hpp
codegen/include/ck/host/operation/gemm.hpp
+20
-0
codegen/include/ck/host/types.hpp
codegen/include/ck/host/types.hpp
+15
-0
codegen/src/device_batched_gemm_softmax_gemm.cpp
codegen/src/device_batched_gemm_softmax_gemm.cpp
+38
-0
codegen/src/device_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp
...vice_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp
+408
-0
codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp
...gen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp
+67
-35
codegen/src/types.cpp
codegen/src/types.cpp
+20
-0
codegen/test/rtc/include/rtc/hip.hpp
codegen/test/rtc/include/rtc/hip.hpp
+1
-0
example/01_gemm/gemm_xdl_streamk.cpp
example/01_gemm/gemm_xdl_streamk.cpp
+4
-0
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
+9
-1
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
+8
-1
example/ck_tile/01_fmha/generate.py
example/ck_tile/01_fmha/generate.py
+2
-1
example/ck_tile/03_gemm/CMakeLists.txt
example/ck_tile/03_gemm/CMakeLists.txt
+3
-0
example/ck_tile/03_gemm/gemm_basic.cpp
example/ck_tile/03_gemm/gemm_basic.cpp
+5
-2
example/ck_tile/03_gemm/gemm_basic.hpp
example/ck_tile/03_gemm/gemm_basic.hpp
+12
-6
example/ck_tile/03_gemm/run_gemm_example.inc
example/ck_tile/03_gemm/run_gemm_example.inc
+64
-44
No files found.
CMakeLists.txt
View file @
b8d11559
...
...
@@ -92,6 +92,7 @@ endif()
add_compile_options
(
-Wno-bit-int-extension
)
add_compile_options
(
-Wno-pass-failed
)
add_compile_options
(
-Wno-switch-default
)
add_compile_options
(
-Wno-unique-object-duplication
)
if
(
DL_KERNELS
)
add_definitions
(
-DDL_KERNELS
)
...
...
Jenkinsfile
View file @
b8d11559
...
...
@@ -117,7 +117,7 @@ def getDockerImage(Map conf=[:]){
{
echo
"Pulling down image: ${image}"
retimage
=
docker
.
image
(
"${image}"
)
withDockerRegistry
([
credentialsId:
"docker_
test_
cred"
,
url:
""
])
{
withDockerRegistry
([
credentialsId:
"
ck_
docker_cred"
,
url:
""
])
{
retimage
.
pull
()
}
}
...
...
@@ -148,7 +148,7 @@ def buildDocker(install_prefix){
//force building the new docker if that parameter is true
echo
"Building image: ${image_name}"
retimage
=
docker
.
build
(
"${image_name}"
,
dockerArgs
)
withDockerRegistry
([
credentialsId:
"docker_
test_
cred"
,
url:
""
])
{
withDockerRegistry
([
credentialsId:
"
ck_
docker_cred"
,
url:
""
])
{
retimage
.
push
()
}
sh
'docker images -q -f dangling=true | xargs --no-run-if-empty docker rmi'
...
...
@@ -162,7 +162,7 @@ def buildDocker(install_prefix){
catch
(
Exception
ex
){
echo
"Unable to locate image: ${image_name}. Building image now"
retimage
=
docker
.
build
(
"${image_name}"
,
dockerArgs
+
' .'
)
withDockerRegistry
([
credentialsId:
"docker_
test_
cred"
,
url:
""
])
{
withDockerRegistry
([
credentialsId:
"
ck_
docker_cred"
,
url:
""
])
{
retimage
.
push
()
}
}
...
...
codegen/include/ck/host/device_batched_gemm_softmax_gemm/operation.hpp
0 → 100644
View file @
b8d11559
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <string>
#include "ck/host/types.hpp"
#include "ck/host/operation/gemm.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"
namespace
ck
{
namespace
host
{
namespace
device_batched_gemm_softmax_gemm
{
// defines all values need for an instance of fwd conv
struct
Operation_Xdl_CShuffle
{
// returns a vector of instances, only given fusion operators: will use default problem spec
static
std
::
vector
<
std
::
vector
<
Operation_Xdl_CShuffle
>>
CreateOperations
(
const
std
::
string
&
prologue
,
const
std
::
string
&
epilogue
);
// returns a vector of instances, given a problem spec and fusion operators
static
std
::
vector
<
Operation_Xdl_CShuffle
>
CreateOperations
(
const
Problem
&
prob
,
const
std
::
string
&
prologue
,
const
std
::
string
&
epilogue
);
TensorDesc
A
{};
TensorDesc
B
{};
TensorDesc
B1
{};
TensorDesc
C
{};
DataType
acc
=
DataType
::
Float
;
DataType
cs_type
=
DataType
::
Half
;
std
::
string
a_elem_op
=
PassThrough
;
std
::
string
b_elem_op
=
PassThrough
;
std
::
string
b1_elem_op
=
PassThrough
;
std
::
string
c_elem_op
=
PassThrough
;
std
::
string
acc_elem_op
=
Scale
;
std
::
string
prologue
=
""
;
std
::
string
epilogue
=
""
;
std
::
string
gemm_specialization
=
"ck::tensor_operation::device::GemmSpecialization::Default"
;
// tuning parameters
operation
::
TileDescGemmGemm
tile_desc
{};
operation
::
BlockTransferDesc
a_block_transfer
{};
operation
::
BlockTransferDesc
b0_block_transfer
{};
operation
::
BlockTransferDesc
b1_block_transfer
{};
operation
::
CShuffleDesc
cshuffle
{};
operation
::
CBlockTransferDesc
c_block_transfer
{};
bool
mask_out_upper_triangle
=
false
;
// functions to update fusion operators if provided
void
update_prologue
(
const
std
::
string
&
prologue
);
void
update_epilogue
(
const
std
::
string
&
epilogue
);
/**constexpr**/
bool
IsSupported
(
std
::
size_t
MRaw_
,
std
::
size_t
NRaw_
,
std
::
size_t
KRaw_
,
std
::
size_t
Gemm1NRaw_
);
// returns a templated instance
Solution
ToSolution
()
const
;
};
}
// namespace device_batched_gemm_softmax_gemm
}
// namespace host
}
// namespace ck
codegen/include/ck/host/device_batched_gemm_softmax_gemm/problem.hpp
0 → 100644
View file @
b8d11559
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <string>
#include "ck/host/types.hpp"
namespace
ck
{
namespace
host
{
namespace
device_batched_gemm_softmax_gemm
{
// defines the problem specification for a GEMM operation
struct
Problem
{
std
::
size_t
M
=
0
;
std
::
size_t
N
=
0
;
std
::
size_t
K
=
0
;
std
::
size_t
O
=
0
;
bool
TransA
=
false
;
bool
TransB
=
false
;
bool
TransB1
=
false
;
bool
TransC
=
false
;
DataType
ADataType
=
DataType
::
Half
;
DataType
BDataType
=
DataType
::
Half
;
DataType
B1DataType
=
DataType
::
Half
;
DataType
CDataType
=
DataType
::
Half
;
std
::
string
AElementOp
=
PassThrough
;
std
::
string
BElementOp
=
PassThrough
;
std
::
string
B1ElementOp
=
PassThrough
;
std
::
string
CElementOp
=
PassThrough
;
std
::
string
AccElementOp
=
Scale
;
// returns the correct device op file for the operation
std
::
string
GetIncludeHeader
()
const
;
// returns a list of instances based on the problem spec and provided fusion operations
std
::
vector
<
Solution
>
GetSolutions
(
const
std
::
string
&
arch
,
const
std
::
string
&
prologue
,
const
std
::
string
&
epilogue
)
const
;
};
}
// namespace device_batched_gemm_softmax_gemm
}
// namespace host
}
// namespace ck
codegen/include/ck/host/device_gemm_multiple_d/operation.hpp
View file @
b8d11559
...
...
@@ -41,6 +41,8 @@ struct Operation_Xdl_CShuffle
operation
::
BlockTransferDesc
b_block_transfer
{};
operation
::
CShuffleDesc
cshuffle
{};
operation
::
CBlockTransferDesc
c_block_transfer
{};
LoopScheduler
loop_scheduler
{};
PipelineVersion
pipeline_version
{};
// functions to update fusion operators if provided
void
update_prologue
(
const
std
::
string
&
prologue
);
...
...
codegen/include/ck/host/operation/gemm.hpp
View file @
b8d11559
...
...
@@ -23,6 +23,26 @@ struct TileDesc
int
n_Xdl_per_wave
=
0
;
int
num_gemmk_prefetch_stage
=
0
;
};
struct
TileDescGemmGemm
{
int
block_size
=
0
;
int
gemm01_m_per_block
=
0
;
int
gemm0_n_per_block
=
0
;
int
gemm0_k_per_block
=
0
;
int
gemm1_n_per_block
=
0
;
int
gemm1_k_per_block
=
0
;
int
ak1
=
0
;
int
bk1
=
0
;
int
b1k1
=
0
;
int
m_per_XDL
=
0
;
int
n_per_XDL
=
0
;
int
gemm0_m_Xdl_per_wave
=
0
;
int
gemm0_n_Xdl_per_wave
=
0
;
int
gemm1_n_Xdl_per_wave
=
0
;
int
num_gemmk_prefetch_stage
=
0
;
};
struct
BlockTransferDesc
{
std
::
string
thread_cluster_length
=
""
;
...
...
codegen/include/ck/host/types.hpp
View file @
b8d11559
...
...
@@ -66,6 +66,20 @@ enum class GemmType
};
std
::
string
ToString
(
GemmType
gt
);
enum
class
LoopScheduler
{
Default
,
Interwave
,
};
std
::
string
ToString
(
LoopScheduler
ls
);
enum
class
PipelineVersion
{
v1
,
v2
};
std
::
string
ToString
(
PipelineVersion
pv
);
struct
TensorDesc
{
DataType
element
;
...
...
@@ -84,6 +98,7 @@ const std::string S = SequenceStr({xs...});
constexpr
const
char
*
PassThrough
=
"ck::tensor_operation::element_wise::PassThrough"
;
constexpr
const
char
*
Bilinear
=
"ck::tensor_operation::element_wise::Bilinear"
;
constexpr
const
char
*
Scale
=
"ck::tensor_operation::element_wise::Scale"
;
}
// namespace host
}
// namespace ck
codegen/src/device_batched_gemm_softmax_gemm.cpp
0 → 100644
View file @
b8d11559
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp"
#include "ck/host/utils.hpp"
#include <algorithm>
namespace
ck
{
namespace
host
{
namespace
device_batched_gemm_softmax_gemm
{
// return the relevant device op file based on the operation
std
::
string
Problem
::
GetIncludeHeader
()
const
{
return
"ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp"
;
}
// returns templated instances when provided with a problem specification
std
::
vector
<
Solution
>
Problem
::
GetSolutions
(
const
std
::
string
&
arch
,
const
std
::
string
&
prologue
,
const
std
::
string
&
epilogue
)
const
{
if
(
get_xdlop_archs
().
count
(
arch
)
==
0
)
return
{};
auto
ops
=
ck
::
host
::
device_batched_gemm_softmax_gemm
::
Operation_Xdl_CShuffle
::
CreateOperations
(
*
this
,
prologue
,
epilogue
);
// obtains vector of instances
std
::
vector
<
Solution
>
result
;
std
::
transform
(
ops
.
begin
(),
ops
.
end
(),
std
::
back_inserter
(
result
),
[
&
](
const
auto
&
op
)
{
return
op
.
ToSolution
();
// template instance with correct values
});
return
result
;
}
}
// namespace device_batched_gemm_softmax_gemm
}
// namespace host
}
// namespace ck
codegen/src/device_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp
0 → 100644
View file @
b8d11559
This diff is collapsed.
Click to expand it.
codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp
View file @
b8d11559
...
...
@@ -62,6 +62,12 @@ void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi)
// accounts for all possible combinations of Row/Col major
static
Layout
ToLayout
(
bool
Trans
)
{
return
Trans
?
Layout
::
Column
:
Layout
::
Row
;
}
// clang-format off
// DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1,
// DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
// clang-format on
// Hard-code tuning parameters in modularized fashion, string them together into a vector of
// instances
std
::
vector
<
Operation_Xdl_CShuffle
>
Operation_Xdl_CShuffle
::
CreateOperations
(
...
...
@@ -83,6 +89,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
1
},
{
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
1
},
{
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
1
},
// Irregular tile
{
64
,
16
,
16
,
32
,
8
,
8
,
16
,
16
,
1
,
1
,
1
},
// clang-format on
};
...
...
@@ -100,6 +108,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
},
// Irregular tile
{
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
},
// clang-format on
};
...
...
@@ -109,15 +119,17 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM|
// Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
// | | | | | | |
{
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
},
{
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
},
{
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
},
{
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
},
// Irregular tile
{
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
},
// clang-format on
{
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
},
{
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
},
{
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
},
{
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
},
};
std
::
vector
<
operation
::
BlockTransferDesc
>
b_block_descriptions_rowmajor
=
{
...
...
@@ -134,6 +146,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
},
// Irregular tile
{
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
},
// clang-format on
};
...
...
@@ -151,6 +165,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
},
// Irregular tile
{
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
},
// clang-format on
};
...
...
@@ -167,6 +183,7 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
// clang-format on
};
...
...
@@ -185,6 +202,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{
S
<
1
,
16
,
1
,
8
>
,
8
},
{
S
<
1
,
32
,
1
,
8
>
,
8
},
{
S
<
1
,
32
,
1
,
8
>
,
8
},
// Irregular tile
{
S
<
1
,
16
,
1
,
4
>
,
1
},
// clang-format on
};
...
...
@@ -199,33 +218,44 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
assert
(
tile_descriptions
.
size
()
==
cshuffle_descriptions
.
size
());
assert
(
tile_descriptions
.
size
()
==
c_block_descriptions
.
size
());
// Put all values together into a single operation > store into the result vector
for
(
std
::
size_t
i
=
0
;
i
<
tile_descriptions
.
size
();
i
++
)
const
std
::
vector
<
std
::
tuple
<
LoopScheduler
,
PipelineVersion
>>
scheduler_pipeline_descriptions
=
{
{
LoopScheduler
::
Default
,
PipelineVersion
::
v1
},
{
LoopScheduler
::
Interwave
,
PipelineVersion
::
v1
},
{
LoopScheduler
::
Default
,
PipelineVersion
::
v2
},
};
for
(
auto
[
loop_scheduler
,
pipeline_version
]
:
scheduler_pipeline_descriptions
)
{
Operation_Xdl_CShuffle
x
;
x
.
tile_desc
=
tile_descriptions
[
i
];
x
.
a_block_transfer
=
a_block_descriptions
[
i
];
x
.
b_block_transfer
=
b_block_descriptions
[
i
];
x
.
cshuffle
=
cshuffle_descriptions
[
i
];
x
.
c_block_transfer
=
c_block_descriptions
[
i
];
x
.
A
=
TensorDesc
{
prob
.
ADataType
,
ToLayout
(
prob
.
TransA
)};
x
.
B
=
TensorDesc
{
prob
.
BDataType
,
ToLayout
(
prob
.
TransB
)};
x
.
E
=
TensorDesc
{
prob
.
EDataType
,
ToLayout
(
prob
.
TransE
)};
x
.
Ds
=
Transform
(
prob
.
DsTrans
,
prob
.
DsDataType
,
[](
auto
trans
,
auto
dt
)
{
return
TensorDesc
{
dt
,
ToLayout
(
trans
)};
});
x
.
a_elem_op
=
prob
.
AElementOp
;
x
.
b_elem_op
=
prob
.
BElementOp
;
x
.
cde_elem_op
=
prob
.
CDEElementOp
;
x
.
gemm_specialization
=
GetGemmSpec
(
prob
.
M
,
prob
.
N
,
prob
.
K
,
x
.
tile_desc
.
m_per_block
,
x
.
tile_desc
.
n_per_block
,
x
.
tile_desc
.
k_per_block
);
x
.
update_prologue
(
prologue
);
x
.
update_epilogue
(
epilogue
);
result
.
push_back
(
x
);
// Put all values together into a single operation > store into the result vector
for
(
std
::
size_t
i
=
0
;
i
<
tile_descriptions
.
size
();
i
++
)
{
Operation_Xdl_CShuffle
x
;
x
.
tile_desc
=
tile_descriptions
[
i
];
x
.
a_block_transfer
=
a_block_descriptions
[
i
];
x
.
b_block_transfer
=
b_block_descriptions
[
i
];
x
.
cshuffle
=
cshuffle_descriptions
[
i
];
x
.
c_block_transfer
=
c_block_descriptions
[
i
];
x
.
A
=
TensorDesc
{
prob
.
ADataType
,
ToLayout
(
prob
.
TransA
)};
x
.
B
=
TensorDesc
{
prob
.
BDataType
,
ToLayout
(
prob
.
TransB
)};
x
.
E
=
TensorDesc
{
prob
.
EDataType
,
ToLayout
(
prob
.
TransE
)};
x
.
Ds
=
Transform
(
prob
.
DsTrans
,
prob
.
DsDataType
,
[](
auto
trans
,
auto
dt
)
{
return
TensorDesc
{
dt
,
ToLayout
(
trans
)};
});
x
.
a_elem_op
=
prob
.
AElementOp
;
x
.
b_elem_op
=
prob
.
BElementOp
;
x
.
cde_elem_op
=
prob
.
CDEElementOp
;
x
.
gemm_specialization
=
GetGemmSpec
(
prob
.
M
,
prob
.
N
,
prob
.
K
,
x
.
tile_desc
.
m_per_block
,
x
.
tile_desc
.
n_per_block
,
x
.
tile_desc
.
k_per_block
);
x
.
loop_scheduler
=
loop_scheduler
;
x
.
pipeline_version
=
pipeline_version
;
x
.
update_prologue
(
prologue
);
x
.
update_epilogue
(
epilogue
);
result
.
push_back
(
x
);
}
}
return
result
;
}
...
...
@@ -263,7 +293,7 @@ static const char* const DeviceGemmMultipleD_Xdl_CShuffleTemplate =
"${BBlockTransferSrcScalarPerVector}, ${BBlockTransferDstScalarPerVector_BK1}, "
"${BBlockLdsExtraN}, ${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, "
"${CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock}, "
"${CDEBlockTransferScalarPerVector_NPerBlock}>"
;
"${CDEBlockTransferScalarPerVector_NPerBlock}
, ${LoopScheduler}, ${PipelineVersion}
>"
;
// use hardcoded instances from vector of operations to substitute values into instance template
Solution
Operation_Xdl_CShuffle
::
ToSolution
()
const
...
...
@@ -336,6 +366,8 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
this
->
c_block_transfer
.
cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl
},
{
"CDEBlockTransferScalarPerVector_NPerBlock"
,
std
::
to_string
(
this
->
c_block_transfer
.
scalar_per_vector_n_wave_n_per_Xdl
)},
{
"LoopScheduler"
,
ToString
(
this
->
loop_scheduler
)},
{
"PipelineVersion"
,
ToString
(
this
->
pipeline_version
)},
};
return
Solution
{
InterpolateString
(
DeviceGemmMultipleD_Xdl_CShuffleTemplate
,
values
),
...
...
codegen/src/types.cpp
View file @
b8d11559
...
...
@@ -59,6 +59,26 @@ std::string ToString(GemmType gt)
throw
std
::
runtime_error
(
"Incorrect gemm type"
);
}
std
::
string
ToString
(
LoopScheduler
ls
)
{
switch
(
ls
)
{
case
LoopScheduler
::
Default
:
return
"ck::LoopScheduler::Default"
;
case
LoopScheduler
::
Interwave
:
return
"ck::LoopScheduler::Interwave"
;
}
throw
std
::
runtime_error
(
"Incorrect LoopScheduler type"
);
}
std
::
string
ToString
(
PipelineVersion
pv
)
{
switch
(
pv
)
{
case
PipelineVersion
::
v1
:
return
"ck::PipelineVersion::v1"
;
case
PipelineVersion
::
v2
:
return
"ck::PipelineVersion::v2"
;
}
throw
std
::
runtime_error
(
"Incorrect PipelineVersion type"
);
}
std
::
string
SequenceStr
(
const
std
::
vector
<
int
>&
v
)
{
return
"ck::Sequence<"
+
...
...
codegen/test/rtc/include/rtc/hip.hpp
View file @
b8d11559
...
...
@@ -8,6 +8,7 @@
#include <memory>
#include <stdexcept>
#include <string>
#include <stdexcept>
namespace
rtc
{
...
...
example/01_gemm/gemm_xdl_streamk.cpp
View file @
b8d11559
...
...
@@ -27,11 +27,15 @@ using DeviceGemmStreamK = ck::tensor_operation::device::DeviceGemmXdlStreamK
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(CK_USE_AMD_MFMA_GFX950)
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ALayout
,
BLayout
,
CLayout
,
AElementOp
,
BElementOp
,
CElementOp
,
256
,
256
,
128
,
4
,
16
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
;
#else // defined(CK_USE_AMD_MFMA_GFX950)
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ALayout
,
BLayout
,
CLayout
,
AElementOp
,
BElementOp
,
CElementOp
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
;
// < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>;
// < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 128, 32, 128, 4, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8>;
#endif // defined(CK_USE_AMD_MFMA_GFX950)
...
...
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
View file @
b8d11559
...
...
@@ -506,6 +506,14 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
cond
&=
deterministic
==
"f"
if
not
cond
:
continue
if
receipt
==
4
:
cond
=
dtype
in
[
'fp16'
,
'bf16'
]
cond
&=
bias
in
[
'no'
,
'bias'
]
cond
&=
dropout
in
[
'no'
,
'dropout_wg32'
,
'dropout_wg16'
]
cond
&=
dpad
==
dvpad
cond
&=
deterministic
==
"f"
if
not
cond
:
continue
api_pool
.
register_dq_dk_dv_traits
(
k
.
api_trait
())
gen
.
append
(
k
)
...
...
@@ -801,4 +809,4 @@ def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_im
_
,
kernels
=
get_bwd_dq_dk_dv_blobs
(
kernel_filter
,
receipt
,
mask_impl
)
for
kernel
in
kernels
:
f
.
write
(
str
(
file_path
.
parent
/
GEN_DIR
/
kernel
.
filename
)
+
"
\n
"
)
f
.
write
(
str
(
file_path
.
parent
/
GEN_DIR
/
FMHA_BWD_API_FILENAME
)
+
"
\n
"
)
\ No newline at end of file
f
.
write
(
str
(
file_path
.
parent
/
GEN_DIR
/
FMHA_BWD_API_FILENAME
)
+
"
\n
"
)
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
View file @
b8d11559
...
...
@@ -487,13 +487,20 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
if
kernel_filter
!=
None
:
if
not
fnmatch
.
fnmatch
(
k
.
name
,
kernel_filter
):
continue
if
receipt
==
2
:
if
receipt
in
(
2
,
3
)
:
cond
=
dtype
in
[
'fp16'
,
'bf16'
]
cond
&=
pipeline
.
F_vlayout
==
'row'
cond
&=
pipeline
.
F_bias
in
[
'no'
,
'alibi'
]
cond
&=
pipeline
.
F_squant
==
'f'
if
not
cond
:
continue
if
receipt
==
4
:
cond
=
dtype
in
[
'fp16'
,
'bf16'
]
cond
&=
pipeline
.
F_vlayout
==
'row'
cond
&=
pipeline
.
F_bias
in
[
'no'
,
'bias'
]
cond
&=
pipeline
.
F_squant
==
'f'
if
not
cond
:
continue
api_pool
.
register_traits
(
k
.
api_trait
())
gen
.
append
(
k
)
...
...
example/ck_tile/01_fmha/generate.py
View file @
b8d11559
...
...
@@ -103,7 +103,8 @@ if __name__ == "__main__":
required
=
False
,
help
=
"codegen receipt. 0: generate only 8xhdim coverage
\n
"
+
\
" 1: generate more instance to cover all hdim
\n
"
+
\
" 2: Only generate instance for Flash attention integration"
" 2: Only generate instance for Flash attention integration
\n
"
+
\
" 4: Only generate instance for PyTorch integration"
)
args
=
parser
.
parse_args
()
...
...
example/ck_tile/03_gemm/CMakeLists.txt
View file @
b8d11559
add_executable
(
tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp
)
add_executable
(
tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp
)
target_compile_options
(
tile_example_gemm_universal PRIVATE
-mllvm -enable-noalias-to-md-conversion=0
)
example/ck_tile/03_gemm/gemm_basic.cpp
View file @
b8d11559
...
...
@@ -82,8 +82,11 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
if
(
s
.
log_level_
>
0
)
{
std
::
cout
<<
"Launching kernel with args:"
<<
" grid: {"
<<
grids
.
x
<<
", "
<<
grids
.
y
<<
", "
<<
grids
.
z
<<
"}"
std
::
cout
<<
"Launching kernel with args: "
<<
Kernel
::
GetName
()
<<
'\n'
<<
"shape: "
<<
CodegenGemmShape
::
GetName
()
<<
'\n'
<<
"problem: "
<<
CodegenPipelineProblem
::
GetName
()
<<
'\n'
<<
"pipeline: "
<<
CodegenGemmPipeline
::
GetName
()
<<
'\n'
<<
"grid: {"
<<
grids
.
x
<<
", "
<<
grids
.
y
<<
", "
<<
grids
.
z
<<
"}"
<<
", blocks: {"
<<
blocks
.
x
<<
", "
<<
blocks
.
y
<<
", "
<<
blocks
.
z
<<
"}"
<<
std
::
endl
;
}
...
...
example/ck_tile/03_gemm/gemm_basic.hpp
View file @
b8d11559
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -11,21 +11,26 @@
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#define CK_TILE_PIPELINE_COMPUTE 1
#define CK_TILE_PIPELINE_COMPUTE
_V3
1
#define CK_TILE_PIPELINE_MEMORY 2
#define CK_TILE_PIPELINE_COMPUTE_V4 3
#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
_V3
#endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_
COMPUTE
)
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_
MEMORY
)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE
_V3
)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV4
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV4
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
#else
#error "unsupported CK_TILE_PIPELINE_DEFAULT value"
#endif
...
...
@@ -126,7 +131,8 @@ auto create_args(int argc, char* argv[])
.
insert
(
"warmup"
,
"50"
,
"number of iterations before benchmark the kernel"
)
.
insert
(
"repeat"
,
"100"
,
"number of iterations to benchmark the kernel"
)
.
insert
(
"timer"
,
"gpu"
,
"gpu:gpu timer, cpu:cpu timer"
)
.
insert
(
"split_k"
,
"1"
,
"splitK value"
);
.
insert
(
"split_k"
,
"1"
,
"splitK value"
)
.
insert
(
"init"
,
"0"
,
"0:random, 1:linear, 2:constant(1)"
);
bool
result
=
arg_parser
.
parse
(
argc
,
argv
);
return
std
::
make_tuple
(
result
,
arg_parser
);
...
...
example/ck_tile/03_gemm/run_gemm_example.inc
View file @
b8d11559
...
...
@@ -30,8 +30,13 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
return
ck_tile
::
make_tuple
(
std
::
max
(
rtol
,
rtol_split_k
),
std
::
max
(
atol
,
atol_split_k
));
}
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
float
invoke_gemm
(
ck_tile
::
DeviceMem
&
a_m_k_dev_buf
,
ck_tile
::
DeviceMem
&
b_k_n_dev_buf
,
ck_tile
::
DeviceMem
&
c_m_n_dev_buf
,
...
...
@@ -57,9 +62,9 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
args
.
stride_B
=
stride_B
;
args
.
stride_C
=
stride_C
;
float
ave_time
=
gemm_calc
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
>
(
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
1
,
n_warmup
,
n_repeat
});
float
ave_time
=
gemm_calc
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
>
(
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
1
,
n_warmup
,
n_repeat
});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
num_byte
=
...
...
@@ -69,14 +74,11 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
std
::
cout
<<
"Run Gemm kernel with M ="
<<
M
<<
" N ="
<<
N
<<
" K ="
<<
K
<<
" StrideA ="
<<
stride_A
<<
" StrideB ="
<<
stride_B
<<
" StrideC ="
<<
stride_C
<<
" A_Layout ="
<<
ALayout
::
name
<<
" B_Layout ="
<<
BLayout
::
name
<<
" C_Layout ="
<<
CLayout
::
name
<<
" A Type = "
<<
DataTypeTraits
<
ADataType
>::
name
<<
" B Type = "
<<
DataTypeTraits
<
BDataType
>::
name
<<
" C Type = "
<<
DataTypeTraits
<
CDataType
>::
name
<<
" : "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
std
::
endl
;
<<
" A_Layout ="
<<
ALayout
::
name
<<
" B_Layout ="
<<
BLayout
::
name
<<
" C_Layout ="
<<
CLayout
::
name
<<
" A Type = "
<<
DataTypeTraits
<
ADataType
>::
name
<<
" B Type = "
<<
DataTypeTraits
<
BDataType
>::
name
<<
" C Type = "
<<
DataTypeTraits
<
CDataType
>::
name
<<
" : "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
std
::
endl
;
return
ave_time
;
}
...
...
@@ -92,10 +94,10 @@ int run_gemm_example_with_layouts(int argc,
if
(
!
result
)
return
-
1
;
using
ADataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
ADataType
;
using
BDataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
BDataType
;
using
CDataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
CDataType
;
using
AccDataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
AccDataType
;
using
ADataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
ADataType
;
using
BDataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
BDataType
;
using
CDataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
CDataType
;
using
AccDataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
AccDataType
;
ck_tile
::
index_t
M
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
N
=
arg_parser
.
get_int
(
"n"
);
...
...
@@ -105,9 +107,10 @@ int run_gemm_example_with_layouts(int argc,
ck_tile
::
index_t
stride_B
=
arg_parser
.
get_int
(
"stride_b"
);
ck_tile
::
index_t
stride_C
=
arg_parser
.
get_int
(
"stride_c"
);
ck_tile
::
index_t
kbatch
=
arg_parser
.
get_int
(
"split_k"
);
int
n_warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
n_repeat
=
arg_parser
.
get_int
(
"repeat"
);
ck_tile
::
index_t
kbatch
=
arg_parser
.
get_int
(
"split_k"
);
int
n_warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
n_repeat
=
arg_parser
.
get_int
(
"repeat"
);
ck_tile
::
index_t
init_method
=
arg_parser
.
get_int
(
"init"
);
stride_A
=
ck_tile
::
get_default_stride
(
M
,
K
,
stride_A
,
is_row_major
(
a_layout
));
stride_B
=
ck_tile
::
get_default_stride
(
K
,
N
,
stride_B
,
is_row_major
(
b_layout
));
...
...
@@ -120,9 +123,26 @@ int run_gemm_example_with_layouts(int argc,
ck_tile
::
HostTensor
<
CDataType
>
c_m_n_dev_result
(
ck_tile
::
host_tensor_descriptor
(
M
,
N
,
stride_C
,
is_row_major
(
CLayout
{})));
// TODO: add different init types
ck_tile
::
FillUniformDistribution
<
ADataType
>
{
-
5.
f
,
5.
f
}(
a_m_k
);
ck_tile
::
FillUniformDistribution
<
BDataType
>
{
-
5.
f
,
5.
f
}(
b_k_n
);
if
(
init_method
==
0
)
{
ck_tile
::
FillUniformDistribution
<
ADataType
>
{
-
1.
f
,
1.
f
}(
a_m_k
);
ck_tile
::
FillUniformDistribution
<
BDataType
>
{
-
1.
f
,
1.
f
}(
b_k_n
);
}
else
if
(
init_method
==
1
)
{
ck_tile
::
FillMonotonicSeq
<
ADataType
>
{}(
a_m_k
);
ck_tile
::
FillMonotonicSeq
<
BDataType
>
{}(
b_k_n
);
}
else
if
(
init_method
==
2
)
{
ck_tile
::
FillConstant
<
ADataType
>
{
static_cast
<
ADataType
>
(
1
)}(
a_m_k
);
ck_tile
::
FillConstant
<
BDataType
>
{
static_cast
<
BDataType
>
(
1
)}(
b_k_n
);
}
else
{
a_m_k
.
SetZero
();
b_k_n
.
SetZero
();
}
ck_tile
::
DeviceMem
a_m_k_dev_buf
(
a_m_k
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
b_k_n_dev_buf
(
b_k_n
.
get_element_space_size_in_bytes
());
...
...
@@ -133,19 +153,19 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_dev_buf
.
SetZero
();
c_m_n_dev_result
.
SetZero
();
invoke_gemm
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
>
(
a_m_k_dev_buf
,
b_k_n_dev_buf
,
c_m_n_dev_buf
,
M
,
N
,
K
,
stride_A
,
stride_B
,
stride_C
,
kbatch
,
n_warmup
,
n_repeat
);
invoke_gemm
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
>
(
a_m_k_dev_buf
,
b_k_n_dev_buf
,
c_m_n_dev_buf
,
M
,
N
,
K
,
stride_A
,
stride_B
,
stride_C
,
kbatch
,
n_warmup
,
n_repeat
);
c_m_n_dev_buf
.
FromDevice
(
c_m_n_dev_result
.
data
());
bool
pass
=
true
;
...
...
@@ -160,9 +180,9 @@ int run_gemm_example_with_layouts(int argc,
a_m_k
,
b_k_n
,
c_m_n_host_ref
);
const
float
max_accumulated_value
=
*
std
::
max_element
(
c_m_n_host_ref
.
mData
.
begin
(),
c_m_n_host_ref
.
mData
.
end
());
const
auto
rtol_atol
=
calculate_rtol_atol
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
K
,
kbatch
,
max_accumulated_value
);
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
const
auto
rtol_atol
=
calculate_rtol_atol
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
K
,
kbatch
,
max_accumulated_value
);
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_host_ref
,
"Error: Incorrect results!"
,
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{}),
...
...
@@ -171,7 +191,7 @@ int run_gemm_example_with_layouts(int argc,
std
::
cout
<<
"Relative error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{})
<<
" Absolute error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
1
>
{})
<<
std
::
endl
;
std
::
cout
<<
"The CPU veification result is:"
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
std
::
cout
<<
"The CPU ve
r
ification result is:"
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
}
else
if
(
arg_parser
.
get_int
(
"v"
)
==
2
)
{
...
...
@@ -218,9 +238,9 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_gpu_buf_ref
.
FromDevice
(
c_m_n_gpu_ref
.
data
());
const
float
max_accumulated_value
=
*
std
::
max_element
(
c_m_n_gpu_ref
.
mData
.
begin
(),
c_m_n_gpu_ref
.
mData
.
end
());
const
auto
rtol_atol
=
calculate_rtol_atol
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
K
,
kbatch
,
max_accumulated_value
);
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
const
auto
rtol_atol
=
calculate_rtol_atol
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
K
,
kbatch
,
max_accumulated_value
);
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_gpu_ref
,
"Error: Incorrect results!"
,
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{}),
...
...
@@ -229,7 +249,7 @@ int run_gemm_example_with_layouts(int argc,
std
::
cout
<<
"Relative error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{})
<<
" Absolute error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
1
>
{})
<<
std
::
endl
;
std
::
cout
<<
"The GPU veification result is: "
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
std
::
cout
<<
"The GPU ve
r
ification result is: "
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
}
return
pass
;
...
...
Prev
1
2
3
4
5
…
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment