gaoqiong / composable_kernel — commit bf5fe7b3 (unverified)

Authored Mar 08, 2023 by rocking5566; committed by GitHub on Mar 08, 2023.

Merge branch 'develop' into conv_dlops/quantization

Parents: 86c32464, 9096b1c7

Changes: 27 files in total. This page shows 20 changed files with 740 additions and 37 deletions (+740 −37).
client_example/17_grouped_gemm_fastgelu/CMakeLists.txt (+2, −0)
client_example/17_grouped_gemm_fastgelu/grouped_gemm_fastgelu.cpp (+232, −0)
docs/Doxyfile (+5, −3)
docs/source/API_Reference_Guide.rst (+35, −6)
docs/source/conf.py (+4, −1)
docs/source/refs.bib (+7, −0)
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp (+10, −5)
include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp (+10, −0)
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp (+9, −4)
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp (+29, −2)
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp (+4, −0)
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp (+6, −1)
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp (+1, −0)
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp (+14, −15)
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fastgelu.hpp (+136, −0)
library/include/ck/library/utility/device_memory.hpp (+4, −0)
library/include/ck/library/utility/fill.hpp (+18, −0)
library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt (+6, −0)
library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instance.cpp (+104, −0)
library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instance.cpp (+104, −0)
client_example/17_grouped_gemm_fastgelu/CMakeLists.txt (new file, mode 100644; no newline at end of file)

```cmake
add_executable(client_grouped_gemm_fastgelu grouped_gemm_fastgelu.cpp)
target_link_libraries(client_grouped_gemm_fastgelu PRIVATE composable_kernel::device_operations)
```
client_example/17_grouped_gemm_fastgelu/grouped_gemm_fastgelu.cpp (new file, mode 100644)

```cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <iomanip>
#include <iostream>
#include <vector>
#include <random>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_fastgelu.hpp"

using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using FastGelu    = ck::tensor_operation::element_wise::FastGelu;

using ADataType  = F16;
using BDataType  = F16;
using DsDataType = ck::Tuple<>;
using EDataType  = F16;

using ALayout  = Row;
using BLayout  = Col;
using DsLayout = ck::Tuple<>;
using ELayout  = Row;

using AElementOp   = PassThrough;
using BElementOp   = PassThrough;
using CDEElementOp = FastGelu;

struct SimpleDeviceMem
{
    SimpleDeviceMem() = delete;

    SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
    {
        (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
    }

    void* GetDeviceBuffer() { return p_mem_; }

    ~SimpleDeviceMem() { (void)hipFree(p_mem_); }

    void* p_mem_;
};

int main()
{
    std::mt19937 gen(19391);
    std::uniform_int_distribution<> distrib(1, 10);

    int group_count = distrib(gen);

    std::vector<int> Ms, Ns, Ks, StrideAs, StrideBs, StrideEs;

    for(int i = 0; i < group_count; ++i)
    {
        Ms.push_back(256 + 256 * distrib(gen));
        Ns.push_back(256 + 256 * distrib(gen));
        Ks.push_back(128 + 128 * distrib(gen));

        StrideAs.push_back(std::is_same<Row, ALayout>::value ? Ks[i] : Ms[i]);
        StrideBs.push_back(std::is_same<Row, BLayout>::value ? Ns[i] : Ks[i]);
        StrideEs.push_back(std::is_same<Row, ELayout>::value ? Ns[i] : Ms[i]);
    }

    auto f_matrix_space_size =
        [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
            using Layout = decltype(layout);

            if constexpr(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value)
            {
                return (nRow - 1) * stride + nCol;
            }
            else
            {
                return (nCol - 1) * stride + nRow;
            }
        };

    std::vector<SimpleDeviceMem> a_dev_bufs, b_dev_bufs, e_dev_bufs;

    a_dev_bufs.reserve(group_count);
    b_dev_bufs.reserve(group_count);
    e_dev_bufs.reserve(group_count);

    std::vector<const void*> p_a, p_b;
    std::vector<void*> p_e;

    p_a.reserve(group_count);
    p_b.reserve(group_count);
    p_e.reserve(group_count);

    std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
    gemm_descs.reserve(group_count);

    for(int i = 0; i < group_count; ++i)
    {
        a_dev_bufs.emplace_back(sizeof(ADataType) *
                                f_matrix_space_size(Ms[i], Ks[i], StrideAs[i], ALayout{}));
        b_dev_bufs.emplace_back(sizeof(BDataType) *
                                f_matrix_space_size(Ks[i], Ns[i], StrideBs[i], BLayout{}));
        e_dev_bufs.emplace_back(sizeof(EDataType) *
                                f_matrix_space_size(Ms[i], Ns[i], StrideEs[i], ELayout{}));

        gemm_descs.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideEs[i], {}});

        p_a.push_back(a_dev_bufs[i].GetDeviceBuffer());
        p_b.push_back(b_dev_bufs[i].GetDeviceBuffer());
        p_e.push_back(e_dev_bufs[i].GetDeviceBuffer());
    }

    using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemm<ALayout,
                                                                     BLayout,
                                                                     DsLayout,
                                                                     ELayout,
                                                                     ADataType,
                                                                     BDataType,
                                                                     DsDataType,
                                                                     EDataType,
                                                                     AElementOp,
                                                                     BElementOp,
                                                                     CDEElementOp>;

    // get device op instances
    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
        DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;

    const auto a_element_op   = AElementOp{};
    const auto b_element_op   = BElementOp{};
    const auto cde_element_op = CDEElementOp{};

    std::string best_op_name;
    bool found            = false;
    int best_op_id        = -1;
    float best_ave_time   = 0;
    float best_tflops     = 0;
    float best_gb_per_sec = 0;

    auto p_ds = std::vector<std::array<const void*, 0>>{};

    // profile device operation instances
    std::cout << "Run all instances and do timing" << std::endl;

    for(int i = 0; i < op_ptrs.size(); ++i)
    {
        auto& op_ptr = op_ptrs[i];

        auto argument_ptr = op_ptr->MakeArgumentPointer(
            p_a, p_b, p_ds, p_e, gemm_descs, a_element_op, b_element_op, cde_element_op);

        auto invoker_ptr = op_ptr->MakeInvokerPointer();

        SimpleDeviceMem gemm_desc_workspace(op_ptr->GetWorkSpaceSize(argument_ptr.get()));
        op_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer());

        std::string op_name = op_ptr->GetTypeString();

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});

            std::size_t flop = 0, num_btype = 0;

            for(std::size_t j = 0; j < gemm_descs.size(); ++j)
            {
                flop += std::size_t(2) * Ms[j] * Ns[j] * Ks[j];
                num_btype += sizeof(ADataType) * Ms[j] * Ks[j] +
                             sizeof(BDataType) * Ks[j] * Ns[j] +
                             sizeof(EDataType) * Ms[j] * Ns[j];
            }

            float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
            float gb_per_sec = num_btype / 1.E6 / ave_time;

            std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops
                      << " TFlops, " << gb_per_sec << " GB/s, " << op_name << std::endl;

            if(tflops > best_tflops)
            {
                found           = true;
                best_op_id      = i;
                best_op_name    = op_name;
                best_tflops     = tflops;
                best_ave_time   = ave_time;
                best_gb_per_sec = gb_per_sec;
            }
        }
        else
        {
            std::cout << op_name << " does not support this problem" << std::endl;
        }
    }

    std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
              << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;

    // run the best instance
    if(found)
    {
        auto& op_ptr = op_ptrs[best_op_id];
        std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
                  << std::endl;

        auto argument_ptr = op_ptr->MakeArgumentPointer(
            p_a, p_b, p_ds, p_e, gemm_descs, a_element_op, b_element_op, cde_element_op);

        auto invoker_ptr = op_ptr->MakeInvokerPointer();

        SimpleDeviceMem gemm_desc_workspace(op_ptr->GetWorkSpaceSize(argument_ptr.get()));
        op_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer());

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
        }

        std::cout << "Done" << std::endl;
    }
    return 0;
}
```
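The CDEElementOp epilogue above applies fast GELU to every output element after each grouped GEMM. As a point of reference only, here is a minimal host-side sketch of the commonly used tanh approximation; the exact constants inside ck::tensor_operation::element_wise::FastGelu are an assumption here, not taken from this commit:

```cpp
#include <cmath>

// Illustrative host reference for a fast-GELU activation. Assumes the common
// tanh approximation 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)));
// CK's FastGelu may use a different but equivalent formulation.
inline float fast_gelu_ref(float x)
{
    const float c = 0.7978845608f; // sqrt(2/pi)
    return 0.5f * x * (1.f + std::tanh(c * (x + 0.044715f * x * x * x)));
}
```

Note also that the profiling loop counts 2·M·N·K flops per group, i.e. only the GEMM itself; the element-wise work of the GELU epilogue is not included in the reported TFlops.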
docs/Doxyfile

```diff
@@ -775,8 +775,10 @@ WARN_LOGFILE =
 # spaces. See also FILE_PATTERNS and EXTENSION_MAPPING
 # Note: If this tag is empty the current directory is searched.
 
-INPUT = ../library/include \
-        ../library/include/internal
+INPUT = ../include/ck/tensor_operation/gpu/grid \
+        ../include/ck/tensor_operation/gpu/block \
+        ../include/ck/tensor_operation/gpu/thread \
+        ../library/include/ck/library/utility
 
 # This tag can be used to specify the character encoding of the source files
 # that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses
@@ -845,7 +847,7 @@ FILE_PATTERNS = *.c \
 # be searched for input files as well.
 # The default value is: NO.
 
-RECURSIVE = NO
+RECURSIVE = YES
 
 # The EXCLUDE tag can be used to specify files and/or directories that should be
 # excluded from the INPUT source files. This way you can easily exclude a
```
docs/source/API_Reference_Guide.rst

```diff
-===================
+*******************
 API Reference Guide
-===================
+*******************
 
-------------
+=================
 Introduction
-------------
+=================
 
 This document contains details of the APIs for the Composable Kernel (CK) library and introduces
 some of the key design principles that are used to write new classes that extend CK functionality.
@@ -16,8 +16,37 @@ Using CK API
 This section describes how to use the CK library API.
 
------------------
+=================
 CK Datatypes
------------------
+=================
 
+-----------------
 DeviceMem
 -----------------
-[TODO]
\ No newline at end of file
+
+.. doxygenstruct:: DeviceMem
+
+---------------------------
+Kernels For Flashattention
+---------------------------
+
+The Flashattention algorithm is defined in :cite:t:`dao2022flashattention`. This section lists
+the classes that are used in the CK GPU implementation of Flashattention.
+
+**Gridwise classes**
+
+.. doxygenstruct:: ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
+
+**Blockwise classes**
+
+.. doxygenstruct:: ck::ThreadGroupTensorSliceTransfer_v4r1
+
+.. doxygenstruct:: ck::BlockwiseGemmXdlops_v2
+
+.. doxygenstruct:: ck::BlockwiseSoftmax
+
+**Threadwise classes**
+
+.. doxygenstruct:: ck::ThreadwiseTensorSliceTransfer_StaticToStatic
+
+.. bibliography::
\ No newline at end of file
```
docs/source/conf.py

```diff
@@ -59,10 +59,13 @@ if read_the_docs_build:
 # Add any Sphinx extension module names here, as strings. They can be
 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
 # ones.
-extensions = ['sphinx.ext.mathjax', 'breathe']
+extensions = ['sphinx.ext.mathjax', 'breathe', 'sphinxcontrib.bibtex']
 
 breathe_projects = {"CK": "../docBin/xml"}
 breathe_default_project = "CK"
 
+bibtex_bibfiles = ['refs.bib']
+
 # Add any paths that contain templates here, relative to this directory.
 templates_path = ['_templates']
```
docs/source/refs.bib (new file, mode 100644)

```bibtex
@article{dao2022flashattention,
  title   = {Flashattention: Fast and memory-efficient exact attention with io-awareness},
  author  = {Dao, Tri and Fu, Daniel Y and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
  journal = {arXiv preprint arXiv:2205.14135},
  year    = {2022}
}
```
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp

```diff
@@ -622,11 +622,16 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
     }
 };
 
-// Blockwise gemm supporting
-// 1. regular XDL output M2_M3_M4_M2 and transposed XDL output M2_N2_N3_N4
-// 2. decoupled input tile descriptor and mma tile descriptor in order to support both vgpr and LDS
-// source buffer
-// 3. configurable k index starting position and step size after each FMA/XDL instruction
+/**
+ * @brief Blockwise gemm
+ *
+ * Supports
+ * 1. regular XDL output M2_M3_M4_M2 and transposed XDL output M2_N2_N3_N4
+ * 2. decoupled input tile descriptor and mma tile descriptor in order to support both vgpr and LDS
+ *    source buffer
+ * 3. configurable k index starting position and step size after each FMA/XDL instruction
+ */
 template <index_t BlockSize,
           typename FloatAB,
           typename FloatAcc,
```
include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp

```diff
@@ -12,6 +12,16 @@
 namespace ck {
 
+/**
+ * @brief Blockwise softmax
+ *
+ * @tparam BlockSize             Block size
+ * @tparam AccDataType           Accumulator data type
+ * @tparam ThreadMap_M_K         Thread id to m_k
+ * @tparam ThreadClusterDesc_M_K Threadwise cluster descriptor
+ * @tparam ThreadSliceDesc_M_K   Threadwise slices descriptor
+ * @tparam IgnoreNaN             Flag to ignore NaN, false by default
+ */
 template <index_t BlockSize,
           typename AccDataType,
           typename ThreadMap_M_K, // thread_id to m_k
```
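For orientation, the per-row computation that BlockwiseSoftmax distributes across a thread block is the standard numerically stable softmax; the @tparam list above describes how rows (M) and reduced elements (K) map onto threads. A freestanding sketch of the math only, not of CK's API:

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Numerically stable softmax over one row: subtract the row max before
// exponentiating so exp() cannot overflow, then normalize by the sum.
// BlockwiseSoftmax performs the same max- and sum-reductions cooperatively
// across a thread block instead of in this serial loop.
std::vector<float> softmax_row(const std::vector<float>& x)
{
    const float m = *std::max_element(x.begin(), x.end());
    std::vector<float> y(x.size());
    float sum = 0.f;
    for(std::size_t i = 0; i < x.size(); ++i)
    {
        y[i] = std::exp(x[i] - m);
        sum += y[i];
    }
    for(float& v : y)
        v /= sum;
    return y;
}
```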
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp

```diff
@@ -11,10 +11,15 @@
 namespace ck {
 
-// this version does following things to avoid scratch memory issue
-// 1. Use StaticallyIndexedArray instead of C array for thread buffer
-// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
-// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
+/**
+ * @brief Blockwise data transfer
+ *
+ * This version does following things to avoid scratch memory issue
+ * 1. Use StaticallyIndexedArray instead of C array for thread buffer
+ * 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
+ * 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
+ *
+ */
 template <typename ThreadGroup,
           typename SrcElementwiseOperation,
           typename DstElementwiseOperation,
```
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp

```diff
@@ -381,6 +381,9 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
             const index_t N = gemm_descs[i].N_;
             const index_t K = gemm_descs[i].K_;
 
+            a_mtx_mraw_kraw_.emplace_back(M, K);
+            b_mtx_nraw_kraw_.emplace_back(N, K);
+
             if(M == 0)
             {
                 skipped_group_count_++;
@@ -485,6 +488,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
         CDEElementwiseOperation c_element_op_;
 
         std::vector<GemmBiasTransKernelArg> gemm_desc_kernel_arg_;
+        std::vector<Tuple<index_t, index_t>> a_mtx_mraw_kraw_;
+        std::vector<Tuple<index_t, index_t>> b_mtx_nraw_kraw_;
 
         index_t grid_size_;
     };
@@ -599,7 +604,28 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
                 return false;
             }
 
-            return true;
+            bool supported = true;
+
+            // If we use padding we do not support vector loads for dimensions not divisible by
+            // vector load size.
+            if constexpr(GemmSpec != GemmSpecialization::Default)
+            {
+                // [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1}
+                // layout, thus we have to adapt it to the {M,K} or {N,K} layout.
+                const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0;
+                const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 1 : 0;
+
+                for(index_t i = 0; i < arg.group_count_; ++i)
+                {
+                    const auto a_vector_dim = arg.a_mtx_mraw_kraw_[i].At(Number<a_raw_vector_dim>{});
+                    const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number<b_raw_vector_dim>{});
+
+                    supported = supported & (a_vector_dim % ABlockTransferSrcScalarPerVector == 0);
+                    supported = supported & (b_vector_dim % BBlockTransferSrcScalarPerVector == 0);
+                }
+            }
+
+            return supported;
         }
 
         // polymorphic
@@ -661,7 +687,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
             << MPerXDL << ", "
             << NPerXDL << ", "
             << MXdlPerWave << ", "
-            << NXdlPerWave
+            << NXdlPerWave << ", "
+            << getGemmSpecializationString(GemmSpec)
             << ">";
         // clang-format on
```
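The new IsSupportedArgument logic reduces to a divisibility test on the raw (unpadded) extent of whichever dimension is read with vectorized loads. A standalone restatement with hypothetical names; the real check is the loop above:

```cpp
// When a non-Default GemmSpecialization pads M/N/K, a vectorized load of
// `scalar_per_vector` elements along a dimension of raw length `raw_len` is
// only safe if raw_len divides evenly. For example, raw_len = 1000 with
// scalar_per_vector = 8 is supported (1000 % 8 == 0); raw_len = 1002 is not.
inline bool vector_load_supported(int raw_len, int scalar_per_vector)
{
    return raw_len % scalar_per_vector == 0;
}
```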
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp

```diff
@@ -18,6 +18,10 @@
 namespace ck {
 
+/**
+ * @brief Gridwise gemm + softmax + gemm fusion
+ *
+ */
 template <typename FloatAB,
           typename FloatGemmAcc,
           typename FloatCShuffle,
```
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp

```diff
@@ -1201,7 +1201,12 @@ struct ThreadwiseTensorSliceTransfer_v4
     SrcCoord src_ref_coord_;
 };
 
-// Do NOT involve any tensor coordinates with StaticBuffer
+/**
+ * @brief Threadwise data transfer
+ *
+ * Do NOT involve any tensor coordinates with StaticBuffer
+ *
+ */
 template <typename SrcData,
           typename DstData,
           typename SrcDesc,
```
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp

```diff
@@ -93,6 +93,7 @@ using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
 using FastGelu    = ck::tensor_operation::element_wise::FastGelu;
 using AddMultiply = ck::tensor_operation::element_wise::AddMultiply;
 using ScaleAdd    = ck::tensor_operation::element_wise::ScaleAdd;
+using Gelu        = ck::tensor_operation::element_wise::Gelu;
 
 template <typename Activation>
 using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Activation>;
```
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp

```diff
@@ -74,8 +74,7 @@ template <typename ALayout,
           typename ADataType,
           typename BDataType,
           typename EDataType>
-struct DeviceOperationInstanceFactory<
-    ck::tensor_operation::device::DeviceGroupedGemm<ALayout,
+struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedGemm<ALayout,
                                                     BLayout,
                                                     Empty_Tuple,
                                                     ELayout,
@@ -83,9 +82,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                                                     BDataType,
                                                     Empty_Tuple,
                                                     EDataType,
-                                                    ck::tensor_operation::element_wise::PassThrough,
-                                                    ck::tensor_operation::element_wise::PassThrough,
-                                                    ck::tensor_operation::element_wise::PassThrough>>
+                                                    PassThrough,
+                                                    PassThrough,
+                                                    PassThrough>>
 {
     using DeviceOp = DeviceGroupedGemm<ALayout,
                                        BLayout,
@@ -95,9 +94,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                                        BDataType,
                                        Empty_Tuple,
                                        EDataType,
-                                       ck::tensor_operation::element_wise::PassThrough,
-                                       ck::tensor_operation::element_wise::PassThrough,
-                                       ck::tensor_operation::element_wise::PassThrough>;
+                                       PassThrough,
+                                       PassThrough,
+                                       PassThrough>;
 
     static auto GetInstances()
     {
```
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fastgelu.hpp (new file, mode 100644)

```cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <memory>
#include <vector>

#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
                                                  Row,
                                                  Empty_Tuple,
                                                  Row,
                                                  F16,
                                                  F16,
                                                  Empty_Tuple,
                                                  F16,
                                                  PassThrough,
                                                  PassThrough,
                                                  FastGelu>>>& instances);

void add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
                                                  Col,
                                                  Empty_Tuple,
                                                  Row,
                                                  F16,
                                                  F16,
                                                  Empty_Tuple,
                                                  F16,
                                                  PassThrough,
                                                  PassThrough,
                                                  FastGelu>>>& instances);

void add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
                                                  Row,
                                                  Empty_Tuple,
                                                  Row,
                                                  F16,
                                                  F16,
                                                  Empty_Tuple,
                                                  F16,
                                                  PassThrough,
                                                  PassThrough,
                                                  FastGelu>>>& instances);

void add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
                                                  Col,
                                                  Empty_Tuple,
                                                  Row,
                                                  F16,
                                                  F16,
                                                  Empty_Tuple,
                                                  F16,
                                                  PassThrough,
                                                  PassThrough,
                                                  FastGelu>>>& instances);

// GroupedGEMM + GELU
template <typename ALayout,
          typename BLayout,
          typename ELayout,
          typename ADataType,
          typename BDataType,
          typename EDataType>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceGroupedGemm<ALayout,
                                                    BLayout,
                                                    Empty_Tuple,
                                                    ELayout,
                                                    ADataType,
                                                    BDataType,
                                                    Empty_Tuple,
                                                    EDataType,
                                                    PassThrough,
                                                    PassThrough,
                                                    FastGelu>>
{
    using DeviceOp = DeviceGroupedGemm<ALayout,
                                       BLayout,
                                       Empty_Tuple,
                                       ELayout,
                                       ADataType,
                                       BDataType,
                                       Empty_Tuple,
                                       EDataType,
                                       PassThrough,
                                       PassThrough,
                                       FastGelu>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
                     is_same_v<EDataType, half_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                         is_same_v<ELayout, Row>)
            {
                add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
                              is_same_v<ELayout, Row>)
            {
                add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
                              is_same_v<ELayout, Row>)
            {
                add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
                              is_same_v<ELayout, Row>)
            {
                add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
            }
        }

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
```
library/include/ck/library/utility/device_memory.hpp

```diff
@@ -14,6 +14,10 @@ __global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size)
     }
 }
 
+/**
+ * @brief Container for storing data in GPU device memory
+ *
+ */
 struct DeviceMem
 {
     DeviceMem() = delete;
```
library/include/ck/library/utility/fill.hpp

```diff
@@ -100,6 +100,15 @@ struct FillMonotonicSeq
                 return tmp;
             });
     }
+
+    template <typename ForwardRange>
+    auto operator()(ForwardRange&& range) const
+        -> std::void_t<decltype(std::declval<const FillMonotonicSeq&>()(
+            std::begin(std::forward<ForwardRange>(range)),
+            std::end(std::forward<ForwardRange>(range))))>
+    {
+        (*this)(std::begin(std::forward<ForwardRange>(range)),
+                std::end(std::forward<ForwardRange>(range)));
+    }
 };
 
 template <typename T>
@@ -112,6 +121,15 @@ struct FillConstant
     {
         std::fill(first, last, value_);
     }
+
+    template <typename ForwardRange>
+    auto operator()(ForwardRange&& range) const
+        -> std::void_t<decltype(std::declval<const FillConstant&>()(
+            std::begin(std::forward<ForwardRange>(range)),
+            std::end(std::forward<ForwardRange>(range))))>
+    {
+        (*this)(std::begin(std::forward<ForwardRange>(range)),
+                std::end(std::forward<ForwardRange>(range)));
+    }
 };
 
 } // namespace utils
```
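Both fillers gain a range overload, SFINAE-constrained via std::void_t so it only participates when the iterator-pair call is well-formed, and simply forward to the existing iterator-pair operator(). A usage sketch; the braced initializers assume FillMonotonicSeq's members are an initial value and a step, and FillConstant's a single value, which this diff does not show:

```cpp
#include <vector>

// Hypothetical usage of the new range overloads: pass a container directly
// instead of a begin()/end() pair. The aggregate initializers below are
// assumptions about member layout, not taken from this diff.
void fill_example()
{
    std::vector<float> seq(8), flat(8);
    ck::utils::FillMonotonicSeq<float>{0.f, 1.f}(seq); // 0, 1, 2, ... (assumed {init, step})
    ck::utils::FillConstant<float>{0.5f}(flat);        // all elements 0.5 (assumed {value})
}
```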
library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt (new file, mode 100644)

```cmake
add_instance_library(device_grouped_gemm_fastgelu_instance
    device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
    device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
    device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instance.cpp
    device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instance.cpp
)
```
library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instance.cpp (new file, mode 100644; diff collapsed, not shown)

library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instance.cpp (new file, mode 100644; diff collapsed, not shown)