Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
51ae4aa2
Commit
51ae4aa2
authored
Oct 16, 2023
by
Adam Osewski
Browse files
DeviceOp + GridwiseGemm Draft GroupedGEMM+SplitK+TileLoop
* First Part: accumulation across tiles in CThreadBuffer
parent
7e2b9da3
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
2460 additions
and
0 deletions
+2460
-0
example/15_grouped_gemm/CMakeLists.txt
example/15_grouped_gemm/CMakeLists.txt
+5
-0
example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp
..._grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp
+334
-0
include/ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp
...tion/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp
+136
-0
include/ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp
...ensor_operation/gpu/device/device_grouped_gemm_splitk.hpp
+3
-0
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp
...grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp
+1010
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle_v2.hpp
.../grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle_v2.hpp
+972
-0
No files found.
example/15_grouped_gemm/CMakeLists.txt
View file @
51ae4aa2
...
...
@@ -35,6 +35,11 @@ add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp8 grouped_gemm_xdl_fi
if
(
result EQUAL 0
)
add_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp8
)
endif
()
add_example_executable
(
example_grouped_gemm_multiple_d_splitk_xdl_fp16 grouped_gemm_multiple_d_splitk_xdl_fp16.cpp
)
if
(
result EQUAL 0
)
add_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_multiple_d_splitk_xdl_fp16
)
endif
()
if
(
USE_BITINT_EXTENSION_INT4
)
add_example_executable
(
example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp
)
...
...
example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp
0 → 100644
View file @
51ae4aa2
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
// Shorthand for a compile-time integer sequence; used below to spell the
// sequence-valued template arguments of DeviceGemmInstance.
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

// Scalar element types used by this example.
using F16 = ck::half_t;
using F32 = float;

// Tensor memory layouts.
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

// Element-wise operation applied to A, B and the C/D/E stage (see the aliases below).
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// Data types: fp16 inputs, fp32 accumulation / shuffle / output.
using ADataType        = F16;
using BDataType        = F16;
using AccDataType      = F32;
using CShuffleDataType = F32;
// Empty tuple: this example passes no additional D tensors.
using DsDataType = ck::Tuple<>;
using EDataType  = F32;

// Layouts: A is row-major, B is column-major, E is row-major; no D tensors.
using ALayout  = Row;
using BLayout  = Col;
using DsLayout = ck::Tuple<>;
using ELayout  = Row;

using AElementOp   = PassThrough;
using BElementOp   = PassThrough;
using CDEElementOp = PassThrough;

// GEMM specialization passed to the device operation instance below.
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// The concrete device operation instantiated by this example. Every argument is
// annotated with the name of the template parameter of
// DeviceGroupedGemmMultipleDSplitKXdlCShuffle that it binds to.
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitKXdlCShuffle
    // clang-format off
    <ALayout,          // ALayout
     BLayout,          // BLayout
     DsLayout,         // DsLayout
     ELayout,          // ELayout
     ADataType,        // ADataType
     BDataType,        // BDataType
     AccDataType,      // AccDataType
     CShuffleDataType, // CShuffleDataType
     DsDataType,       // DsDataType
     EDataType,        // EDataType
     AElementOp,       // AElementwiseOperation
     BElementOp,       // BElementwiseOperation
     CDEElementOp,     // CDEElementwiseOperation
     GemmMNKPadding,   // GemmSpec (GEMM specialization)
     1,                // NumGemmKPrefetchStage
     256,              // BlockSize
     64,               // MPerBlock
     128,              // NPerBlock
     32,               // KPerBlock
     8,                // AK1
     8,                // BK1
     32,               // MPerXDL
     32,               // NPerXDL
     1,                // MXdlPerWave
     2,                // NXdlPerWave
     S<1, 4, 64, 1>,   // ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1
     S<0, 2, 1, 3>,    // ABlockTransferThreadClusterArrangeOrder
     S<0, 2, 1, 3>,    // ABlockTransferSrcAccessOrder
     3,                // ABlockTransferSrcVectorDim
     8,                // ABlockTransferSrcScalarPerVector
     8,                // ABlockTransferDstScalarPerVector_AK1
     false,            // AThreadTransferSrcResetCoordinateAfterRun
     1,                // ABlockLdsExtraM
     S<1, 4, 64, 1>,   // BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1
     S<0, 2, 1, 3>,    // BBlockTransferThreadClusterArrangeOrder
     S<0, 2, 1, 3>,    // BBlockTransferSrcAccessOrder
     3,                // BBlockTransferSrcVectorDim
     8,                // BBlockTransferSrcScalarPerVector
     8,                // BBlockTransferDstScalarPerVector_BK1
     false,            // BThreadTransferSrcResetCoordinateAfterRun
     1,                // BBlockLdsExtraN
     1,                // CShuffleMXdlPerWavePerShuffle
     1,                // CShuffleNXdlPerWavePerShuffle
     S<1, 32, 1, 8>,   // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
     4>;               // CDEShuffleBlockTransferScalarPerVector_NPerBlock
// clang-format on
/// Host-side description of a grouped GEMM problem.
///
/// All vectors are indexed by group: entry `i` of each vector describes the
/// i-th GEMM. `group_count` is the number of groups the consumer iterates over,
/// so every vector must hold at least `group_count` entries.
struct ProblemSize final
{
    std::vector<ck::index_t> Ms; ///< M dimension of each group.
    std::vector<ck::index_t> Ns; ///< N dimension of each group.
    std::vector<ck::index_t> Ks; ///< K dimension of each group.

    std::vector<ck::index_t> stride_As; ///< Stride of the A tensor of each group.
    std::vector<ck::index_t> stride_Bs; ///< Stride of the B tensor of each group.
    std::vector<ck::index_t> stride_Cs; ///< Stride of the C/E tensor of each group.

    /// Number of GEMM groups. Zero-initialized so that a default-constructed
    /// object never carries an indeterminate count.
    ck::index_t group_count = 0;
};
/// Run-time switches of the example; the defaults apply unless overridden
/// from the command line.
struct ExecutionConfig final
{
    /// Verification switch (command line: 0=no, 1=yes).
    bool do_verification{true};
    /// Input initialization selector (command line: 0=no init, 1=integer value,
    /// 2=decimal value).
    int init_method{1};
    /// Number of batches the K dimension is split into (documented as > 0).
    int k_batch{1};
    /// Kernel timing switch (command line: 0=no, 1=yes).
    bool time_kernel{false};
};
bool
run_grouped_gemm
(
const
ProblemSize
&
problem_size
,
const
ExecutionConfig
&
config
)
{
auto
group_count
=
problem_size
.
group_count
;
// GEMM shape
std
::
vector
<
ck
::
tensor_operation
::
device
::
GemmDesc
>
gemm_descs
;
std
::
vector
<
void
*>
p_Cs
;
std
::
vector
<
const
void
*>
p_As
;
std
::
vector
<
const
void
*>
p_Bs
;
gemm_descs
.
reserve
(
group_count
);
p_As
.
reserve
(
group_count
);
p_Bs
.
reserve
(
group_count
);
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
using
namespace
ck
::
literals
;
if
(
std
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
return
HostTensorDescriptor
({
row
,
col
},
{
stride
,
1
_uz
});
}
else
{
return
HostTensorDescriptor
({
row
,
col
},
{
1
_uz
,
stride
});
}
};
std
::
vector
<
Tensor
<
ADataType
>>
a_tensors
;
std
::
vector
<
Tensor
<
BDataType
>>
b_tensors
;
std
::
vector
<
Tensor
<
EDataType
>>
c_host_tensors
;
std
::
vector
<
Tensor
<
EDataType
>>
c_device_tensors
;
a_tensors
.
reserve
(
group_count
);
b_tensors
.
reserve
(
group_count
);
c_host_tensors
.
reserve
(
group_count
);
c_device_tensors
.
reserve
(
group_count
);
using
DeviceMemPtr
=
std
::
unique_ptr
<
DeviceMem
>
;
std
::
vector
<
DeviceMemPtr
>
a_tensors_device
,
b_tensors_device
,
c_tensors_device
;
a_tensors_device
.
reserve
(
group_count
);
b_tensors_device
.
reserve
(
group_count
);
c_tensors_device
.
reserve
(
group_count
);
std
::
size_t
flop
=
0
,
num_btype
=
0
;
for
(
int
i
=
0
;
i
<
group_count
;
i
++
)
{
a_tensors
.
push_back
(
Tensor
<
ADataType
>
(
f_host_tensor_descriptor
(
problem_size
.
Ms
[
i
],
problem_size
.
Ks
[
i
],
problem_size
.
stride_As
[
i
],
ALayout
{})));
b_tensors
.
push_back
(
Tensor
<
BDataType
>
(
f_host_tensor_descriptor
(
problem_size
.
Ks
[
i
],
problem_size
.
Ns
[
i
],
problem_size
.
stride_Bs
[
i
],
BLayout
{})));
c_host_tensors
.
push_back
(
Tensor
<
EDataType
>
(
f_host_tensor_descriptor
(
problem_size
.
Ms
[
i
],
problem_size
.
Ns
[
i
],
problem_size
.
stride_Cs
[
i
],
ELayout
{})));
c_device_tensors
.
push_back
(
Tensor
<
EDataType
>
(
f_host_tensor_descriptor
(
problem_size
.
Ms
[
i
],
problem_size
.
Ns
[
i
],
problem_size
.
stride_Cs
[
i
],
ELayout
{})));
std
::
cout
<<
"gemm["
<<
i
<<
"] a_m_k: "
<<
a_tensors
[
i
].
mDesc
<<
" b_k_n: "
<<
b_tensors
[
i
].
mDesc
<<
" c_m_n: "
<<
c_device_tensors
[
i
].
mDesc
<<
std
::
endl
;
flop
+=
std
::
size_t
(
2
)
*
problem_size
.
Ms
[
i
]
*
problem_size
.
Ks
[
i
]
*
problem_size
.
Ns
[
i
];
num_btype
+=
sizeof
(
ADataType
)
*
a_tensors
[
i
].
mDesc
.
GetElementSize
()
+
sizeof
(
BDataType
)
*
b_tensors
[
i
].
mDesc
.
GetElementSize
()
+
sizeof
(
EDataType
)
*
c_device_tensors
[
i
].
mDesc
.
GetElementSize
();
switch
(
config
.
init_method
)
{
case
0
:
break
;
case
1
:
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
5
,
5
});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
break
;
case
2
:
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
break
;
default:
// a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
// b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
ck
::
utils
::
FillConstant
<
ADataType
>
{
1.
f
}(
a_tensors
[
i
]);
ck
::
utils
::
FillConstant
<
BDataType
>
{
1.
f
}(
b_tensors
[
i
]);
}
}
using
GroupedGemmKernelArgument
=
ck
::
tensor_operation
::
device
::
GroupedGemmMultipleDKernelArguments
<>
;
std
::
vector
<
GroupedGemmKernelArgument
>
grouped_gemm_kernel_args_
;
grouped_gemm_kernel_args_
.
reserve
(
group_count
);
for
(
int
i
=
0
;
i
<
group_count
;
i
++
)
{
a_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
ADataType
)
*
problem_size
.
Ms
[
i
]
*
problem_size
.
Ks
[
i
]));
b_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
BDataType
)
*
problem_size
.
Ns
[
i
]
*
problem_size
.
Ks
[
i
]));
c_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
EDataType
)
*
problem_size
.
Ms
[
i
]
*
problem_size
.
Ns
[
i
]));
a_tensors_device
[
i
]
->
ToDevice
(
a_tensors
[
i
].
mData
.
data
(),
a_tensors
[
i
].
mDesc
.
GetElementSpaceSize
()
*
sizeof
(
ADataType
));
b_tensors_device
[
i
]
->
ToDevice
(
b_tensors
[
i
].
mData
.
data
(),
b_tensors
[
i
].
mDesc
.
GetElementSpaceSize
()
*
sizeof
(
BDataType
));
c_tensors_device
[
i
]
->
SetZero
();
p_As
.
push_back
(
a_tensors_device
[
i
]
->
GetDeviceBuffer
());
p_Bs
.
push_back
(
b_tensors_device
[
i
]
->
GetDeviceBuffer
());
p_Cs
.
push_back
(
c_tensors_device
[
i
]
->
GetDeviceBuffer
());
gemm_descs
.
push_back
({
problem_size
.
Ms
[
i
],
problem_size
.
Ns
[
i
],
problem_size
.
Ks
[
i
],
problem_size
.
stride_As
[
i
],
problem_size
.
stride_Bs
[
i
],
problem_size
.
stride_Cs
[
i
],
{}});
grouped_gemm_kernel_args_
.
push_back
({
a_tensors_device
[
i
]
->
GetDeviceBuffer
(),
b_tensors_device
[
i
]
->
GetDeviceBuffer
(),
{},
c_tensors_device
[
i
]
->
GetDeviceBuffer
(),
problem_size
.
Ms
[
i
],
problem_size
.
Ns
[
i
],
problem_size
.
Ks
[
i
],
problem_size
.
stride_As
[
i
],
problem_size
.
stride_Bs
[
i
],
{},
problem_size
.
stride_Cs
[
i
]});
}
auto
a_element_op
=
AElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
c_element_op
=
CDEElementOp
{};
auto
gemm
=
DeviceGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
std
::
vector
<
std
::
array
<
const
void
*
,
0
>>
p_Ds
=
{};
// do GEMM
auto
argument
=
gemm
.
MakeArgument
(
p_As
,
p_Bs
,
p_Ds
,
p_Cs
,
gemm_descs
,
a_element_op
,
b_element_op
,
c_element_op
);
DeviceMem
gemm_arg_dev_mem
(
gemm
.
GetDeviceKernelArgSize
(
&
argument
));
DeviceMem
gemm_workspace_dev
(
gemm
.
GetWorkSpaceSize
(
&
argument
));
gemm
.
SetWorkSpacePointer
(
&
argument
,
gemm_workspace_dev
.
GetDeviceBuffer
());
hip_check_error
(
hipMemcpy
(
gemm_arg_dev_mem
.
GetDeviceBuffer
(),
grouped_gemm_kernel_args_
.
data
(),
gemm
.
GetDeviceKernelArgSize
(
&
argument
),
hipMemcpyHostToDevice
));
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
{
throw
std
::
runtime_error
(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem"
);
}
gemm
.
SetDeviceKernelArgs
(
argument
,
gemm_arg_dev_mem
.
GetDeviceBuffer
());
gemm
.
SetKBatchSize
(
argument
,
config
.
k_batch
);
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
if
(
config
.
time_kernel
)
{
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
config
.
time_kernel
});
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
gemm
.
GetTypeString
()
<<
std
::
endl
;
}
bool
pass
=
true
;
if
(
config
.
do_verification
)
{
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
EDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
>
;
for
(
std
::
size_t
i
=
0
;
i
<
gemm_descs
.
size
();
i
++
)
{
c_tensors_device
[
i
]
->
FromDevice
(
c_device_tensors
[
i
].
mData
.
data
(),
c_device_tensors
[
i
].
mDesc
.
GetElementSize
()
*
sizeof
(
EDataType
));
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_tensors
[
i
],
b_tensors
[
i
],
c_host_tensors
[
i
],
a_element_op
,
b_element_op
,
c_element_op
);
ref_invoker
.
Run
(
ref_argument
);
pass
&=
ck
::
utils
::
check_err
(
c_device_tensors
[
i
],
c_host_tensors
[
i
]);
}
}
return
pass
;
}
int
main
(
int
argc
,
char
*
argv
[])
{
ProblemSize
problem_size
;
ExecutionConfig
config
;
std
::
vector
<
ck
::
index_t
>
Ms
{
64
};
problem_size
.
group_count
=
Ms
.
size
();
for
(
int
i
=
0
;
i
<
problem_size
.
group_count
;
i
++
)
{
problem_size
.
Ms
.
push_back
(
Ms
[
i
]);
problem_size
.
Ns
.
push_back
(
128
);
problem_size
.
Ks
.
push_back
(
128
);
problem_size
.
stride_As
.
push_back
(
problem_size
.
Ks
[
i
]);
problem_size
.
stride_Bs
.
push_back
(
problem_size
.
Ks
[
i
]);
problem_size
.
stride_Cs
.
push_back
(
problem_size
.
Ns
[
i
]);
}
if
(
argc
==
5
)
{
config
.
do_verification
=
std
::
stoi
(
argv
[
1
]);
config
.
init_method
=
std
::
stoi
(
argv
[
2
]);
config
.
time_kernel
=
std
::
stoi
(
argv
[
3
]);
config
.
k_batch
=
std
::
stoi
(
argv
[
4
]);
}
else
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=n0, 1=yes)
\n
"
);
printf
(
"arg4: k_batch (> 0)
\n
"
);
exit
(
0
);
}
return
!
run_grouped_gemm
(
problem_size
,
config
);
}
include/ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp
0 → 100644
View file @
51ae4aa2
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include <iostream>
#include <vector>
#include <sstream>
#include "device_grouped_gemm.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
///
/// @brief Arguments of a single GEMM problem within a grouped GEMM.
///
///        An array of these structures is handed to the GroupedGEMM entry point
///        kernel through a pointer.
///
/// @tparam NumDTensor  The number of D input tensors.
///
template <index_t NumDTensor = 0>
struct GroupedGemmMultipleDKernelArguments
{
    /// Stores every argument in the member of the same name (without the trailing underscore).
    __host__ __device__ GroupedGemmMultipleDKernelArguments(
        const void* p_a_grid_,
        const void* p_b_grid_,
        std::array<const void*, NumDTensor> p_ds_grid_,
        void* p_e_grid_,
        index_t M_,
        index_t N_,
        index_t K_,
        index_t StrideA_,
        index_t StrideB_,
        std::array<index_t, NumDTensor> StrideDs_,
        index_t StrideE_)
        : p_a_grid(p_a_grid_),
          p_b_grid(p_b_grid_),
          p_ds_grid(p_ds_grid_),
          p_e_grid(p_e_grid_),
          M(M_),
          N(N_),
          K(K_),
          StrideA(StrideA_),
          StrideB(StrideB_),
          StrideDs(StrideDs_),
          StrideE(StrideE_)
    {
    }

    // Tensor pointers.
    const void* p_a_grid;
    const void* p_b_grid;
    std::array<const void*, NumDTensor> p_ds_grid;
    void* p_e_grid;

    // Problem dimensions.
    index_t M;
    index_t N;
    index_t K;

    // Tensor strides.
    index_t StrideA;
    index_t StrideB;
    std::array<index_t, NumDTensor> StrideDs;
    index_t StrideE;

    /// Writes the dimensions and strides to std::cout on a single line.
    void Print() const
    {
        // Every D stride is followed by a comma, including the last one.
        std::stringstream ds_strides;
        for(const auto stride_d : StrideDs)
        {
            ds_strides << stride_d << ",";
        }

        auto& os = std::cout;
        os << "arg {";
        os << "M:" << M << ", ";
        os << "N:" << N << ", ";
        os << "K:" << K << ", ";
        os << "SA:" << StrideA << ", ";
        os << "SB:" << StrideB << ", ";
        os << "SE:" << StrideE << ", ";
        os << "SDs: {" << ds_strides.str() << "}";
        os << "}" << std::endl;
    }
};
///
/// @brief Abstract device-operation interface derived from DeviceGroupedGemm.
///
///        On top of the base interface it declares three pure virtual hooks:
///        setting the k-batch size, setting the pointer to the device-resident
///        kernel arguments and querying the size of those kernel arguments.
///        All template parameters are forwarded unchanged to DeviceGroupedGemm.
///
template <typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename EDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation>
struct DeviceGroupedGemmMultipleDSplitK : public DeviceGroupedGemm<ALayout,
                                                                   BLayout,
                                                                   DsLayout,
                                                                   ELayout,
                                                                   ADataType,
                                                                   BDataType,
                                                                   DsDataType,
                                                                   EDataType,
                                                                   AElementwiseOperation,
                                                                   BElementwiseOperation,
                                                                   CDEElementwiseOperation>
{
    //----------------------------------------------------------------------------------------------
    /// @brief      Sets the k batch size.
    ///
    /// @param      p_arg   Pointer to the Argument we're going to change.
    /// @param[in]  kbatch  The kbatch value.
    ///
    virtual void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const = 0;

    //----------------------------------------------------------------------------------------------
    /// @brief      Sets the device kernel arguments pointer.
    ///
    /// @param      p_arg              The pointer to the Argument we're going to update.
    /// @param[in]  p_dev_kernel_args  The pointer to the device memory which contains kernel
    ///                                arguments.
    ///
    virtual void SetDeviceKernelArgs(BaseArgument* p_arg, const void* p_dev_kernel_args) const = 0;

    //----------------------------------------------------------------------------------------------
    /// @brief      Gets the device kernel argument size.
    ///
    /// @param[in]  p_arg  The pointer to the Device op Argument.
    ///
    /// @return     The device kernel argument size.
    ///
    virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0;
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp
View file @
51ae4aa2
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp
0 → 100644
View file @
51ae4aa2
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <tuple>
#include "ck/ck.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/host_utility/stream_utility.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/tuple.hpp"
#include <ck/utility/work_scheduling.hpp>
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle_v2.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
///
/// @brief      Entry point kernel for device-wide Grouped GEMM operation.
///
/// @param[in]  gemm_descs_const  The pointer to the array of GEMM descriptor structures.
/// @param[in]  p_workspace       Pointer to the auxiliary workgroup workspace used to store
///                               partial results. The synchronization flags are placed right
///                               after the accumulation area of this workspace.
/// @param[in]  tile_count        The overall number of output tiles we divided all groups into.
/// @param[in]  k_batch           The number of batches we split the K dimension into.
/// @param[in]  a_element_op      Element-wise operation forwarded to the GEMM for tensor A.
/// @param[in]  b_element_op      Element-wise operation forwarded to the GEMM for tensor B.
/// @param[in]  cde_element_op    Element-wise operation for the C/D/E stage; not used yet
///                               because the write-out path is still a TODO.
///
/// @tparam     GridwiseGemm             The specific GridwiseGEMM algorithm implementation.
/// @tparam     GemmDesc                 The structure holding all necessary descriptors and
///                                      other data needed for grouped gemm calculation and work
///                                      distribution.
/// @tparam     FloatA                   Input tensor A elements' data type.
/// @tparam     FloatB                   Input tensor B elements' data type.
/// @tparam     FloatC                   Tensor C elements' data type. Currently only referenced
///                                      from the commented-out write-out code.
/// @tparam     Block2ETileMapKSplit     The structure providing mapping between workgroup ids,
///                                      the data tiles to process and the output tiles.
/// @tparam     AElementwiseOperation    Type of @p a_element_op.
/// @tparam     BElementwiseOperation    Type of @p b_element_op.
/// @tparam     CDEElementwiseOperation  Type of @p cde_element_op.
/// @tparam     HasMainKBlockLoop        Flag indicating whether all GEMM problem configurations
///                                      need to loop over tiles in K dimension.
///
template <typename GridwiseGemm,
          typename GemmDesc,
          typename FloatA,
          typename FloatB,
          typename FloatC,
          typename Block2ETileMapKSplit,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_grouped_gemm_xdl_splitk_v2(
            const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
            void* const __restrict__ p_workspace,
            const index_t tile_count,
            const index_t k_batch,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            [[maybe_unused]] const CDEElementwiseOperation cde_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))

    constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
    __shared__ uint8_t p_shared[shared_size];

    const auto gemm_desc_ptr =
        reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));

    // The flags live in the workspace right behind the accumulation area.
    uint32_t* const __restrict__ p_flags = reinterpret_cast<uint32_t* const __restrict__>(
        reinterpret_cast<char*>(p_workspace) +
        Block2ETileMapKSplit::GetAccWorkspaceSize(sizeof(typename GridwiseGemm::AccType)));

    StridedReductionTileLoop work_scheduler{tile_count, p_flags};

    // early exit if no work.
    if(work_scheduler.tile_id_ >= tile_count)
        return;

    // Zero the flags, one flag per thread (by global thread id).
    // NOTE(review): no grid-wide barrier is visible between this store and the
    // first FlagFinished() below - confirm that the flag protocol tolerates it.
    if(get_thread_global_1d_id() < work_scheduler.GetFlagCount(k_batch))
        p_flags[get_thread_global_1d_id()] = 0;

    index_t group_id = 0;
    index_t offset   = 0;

    auto M = gemm_desc_ptr[group_id].M;
    auto N = gemm_desc_ptr[group_id].N;

    auto b2c_tile_map     = Block2ETileMapKSplit(M, N, k_batch);
    index_t grid_size_grp = b2c_tile_map.CalculateGridSize(M, N);

    index_t gemm_tile_id_start = 0;
    index_t gemm_tile_id_end   = grid_size_grp;

    do
    {
        // Find corresponding GEMM group for our tile
        while(!(work_scheduler.tile_id_ >= gemm_tile_id_start &&
                work_scheduler.tile_id_ < gemm_tile_id_end))
        {
            offset += grid_size_grp;
            group_id++;

            M = gemm_desc_ptr[group_id].M;
            N = gemm_desc_ptr[group_id].N;

            b2c_tile_map  = Block2ETileMapKSplit(M, N, k_batch);
            grid_size_grp = b2c_tile_map.CalculateGridSize(M, N);

            gemm_tile_id_start = offset;
            gemm_tile_id_end   = offset + grid_size_grp;
        }

        const auto p_a_grid = reinterpret_cast<const FloatA*>(gemm_desc_ptr[group_id].p_a_grid);
        const auto p_b_grid = reinterpret_cast<const FloatB*>(gemm_desc_ptr[group_id].p_b_grid);
        // const auto p_c_grid = reinterpret_cast<FloatC*>(gemm_desc_ptr[group_id].p_c_grid);

        const auto K       = gemm_desc_ptr[group_id].K;
        const auto StrideA = gemm_desc_ptr[group_id].StrideA;
        const auto StrideB = gemm_desc_ptr[group_id].StrideB;
        // const auto StrideC = gemm_desc_ptr[group_id].StrideC;

        auto gridwise_gemm = GridwiseGemm();
        // Not consumed yet: reserved for the pending reduction / write-out steps (see TODOs).
        [[maybe_unused]] auto& results_buffer = gridwise_gemm.GetCThreadBuffer();
        b2c_tile_map.CalculateBottomIndex(work_scheduler.tile_id_ - offset);

        // Iterate over K dimension for this [M,N] tile
        // still in the same GEMM && the same [M,N] tile
        do
        {
            // just accumulate results in registers!
            gridwise_gemm.template RunGEMM<HasMainKBlockLoop>(p_a_grid,
                                                              p_b_grid,
                                                              static_cast<void*>(p_shared),
                                                              a_element_op,
                                                              b_element_op,
                                                              M,
                                                              N,
                                                              K,
                                                              StrideA,
                                                              StrideB,
                                                              k_batch,
                                                              b2c_tile_map);

        } while(work_scheduler.GetNextTile() && b2c_tile_map.GetNextKTileIdx());

        // if (changed group_id || next [M,N] tile)
        if(!b2c_tile_map.IsFirstKSplitBlock())
        {
            // TODO:
            // Store partial results to auxiliary workspace.
            // make results buffer tensor descriptor (registers).
            // make workspace gmem tensor descriptor
            // create ThreadGroupTransform and run copy
        }

        const index_t output_tile_idx =
            __builtin_amdgcn_readfirstlane(b2c_tile_map.GetOutputTileIdx());
        const index_t output_tile_idx_offset = __builtin_amdgcn_readfirstlane(offset / k_batch);

        work_scheduler.FlagFinished(k_batch, output_tile_idx, output_tile_idx_offset);

        // The workgroup which processed first K tile accumulates results and stores to GMEM
        if(b2c_tile_map.IsFirstKSplitBlock())
        {
            // Wait until all other blocks for this [M,N] tile store their results.
            work_scheduler.WaitForNeighbours(k_batch, output_tile_idx, output_tile_idx_offset);

            // Accumulate partial results. We can have different # of workgroups to reduce, thus we
            // read actual flag value.
            [[maybe_unused]] const index_t flag_v = __builtin_amdgcn_readfirstlane(
                work_scheduler.GetFlagValue(k_batch, output_tile_idx, output_tile_idx_offset));

            // TODO: do blockwise reduction from workspace (GMEM) to results_buffer (registers)

            // Signal waiting blocks that they can start use their workspace.
            work_scheduler.Reset(k_batch, output_tile_idx, output_tile_idx_offset);

            // TODO do fusion, cshuffle and store results to GMEM
            // gridwise_gemm.RunWrite(results_buffer,
            //                        p_c_grid,
            //                        M,
            //                        N,
            //                        K,
            //                        StrideA,
            //                        StrideB,
            //                        StrideC,
            //                        MPadded,
            //                        NPadded,
            //                        KPadded,
            //                        K0,
            //                        k_batch,
            //                        static_cast<void*>(p_shared),
            //                        b2c_tile_map);
        }
        else
        {
            // TODO: double buffering in order to not wait for this.
            work_scheduler.WaitForReduction(k_batch, output_tile_idx, output_tile_idx_offset);
        }
    } while(work_scheduler.HasTile());
#else
    // Unsupported target: reference every parameter so none is reported as unused.
    ignore = gemm_descs_const;
    ignore = p_workspace;
    ignore = tile_count;
    ignore = k_batch;
    ignore = a_element_op;
    ignore = b_element_op;
#endif // end of supported-target check
}
template
<
typename
ALayout
,
typename
BLayout
,
typename
DsLayout
,
typename
ELayout
,
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
GemmSpecialization
GemmSpec
,
ck
::
index_t
NumGemmKPrefetchStage
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
KPerBlock
,
ck
::
index_t
AK1
,
ck
::
index_t
BK1
,
ck
::
index_t
MPerXDL
,
ck
::
index_t
NPerXDL
,
ck
::
index_t
MXdlPerWave
,
ck
::
index_t
NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_AK1
,
bool
AThreadTransferSrcResetCoordinateAfterRun
,
index_t
ABlockLdsExtraM
,
typename
BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_BK1
,
bool
BThreadTransferSrcResetCoordinateAfterRun
,
index_t
BBlockLdsExtraN
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CDEShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
,
typename
ComputeDataType
=
EDataType
>
struct
DeviceGroupedGemmMultipleDSplitKXdlCShuffle
:
public
DeviceGroupedGemmMultipleDSplitK
<
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
ADataType
,
BDataType
,
DsDataType
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
>
{
using
DeviceOp
=
DeviceGroupedGemmMultipleDSplitKXdlCShuffle
;
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
using
GridwiseGemm
=
GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
<
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
GemmSpec
,
NumGemmKPrefetchStage
,
BlockSize
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
AK1
,
BK1
,
MPerXDL
,
NPerXDL
,
MXdlPerWave
,
NXdlPerWave
,
ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_AK1
,
AThreadTransferSrcResetCoordinateAfterRun
,
ABlockLdsExtraM
,
BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
BThreadTransferSrcResetCoordinateAfterRun
,
BBlockLdsExtraN
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CDEShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopSched
,
PipelineVer
>
;
using
KernelArguments
=
GroupedGemmMultipleDKernelArguments
<
NumDTensor
>
;
using
Block2ETileMapKSplit
=
BlockToCTileMap_LinearKSplit
<
MPerBlock
,
NPerBlock
>
;
static
constexpr
index_t
DefaultKBatch
=
1
;
// Argument
struct
Argument
:
public
BaseArgument
{
/// @brief Construct the argument with the default split-K factor.
///
/// Forwards every parameter unchanged to the full constructor and passes
/// DefaultKBatch as the kbatch value.
Argument(std::vector<const void*>& p_As,
         std::vector<const void*>& p_Bs,
         std::vector<std::array<const void*, NumDTensor>>& p_Ds,
         std::vector<void*>& p_Es,
         std::vector<GemmDesc>& gemm_descs,
         AElementwiseOperation a_element_op,
         BElementwiseOperation b_element_op,
         CDEElementwiseOperation cde_element_op,
         int occupancy_num_blocks,
         int gpu_cu_count)
    : Argument(p_As,
               p_Bs,
               p_Ds,
               p_Es,
               gemm_descs,
               a_element_op,
               b_element_op,
               cde_element_op,
               DefaultKBatch,
               occupancy_num_blocks,
               gpu_cu_count)
{
}
Argument
(
std
::
vector
<
const
void
*>&
p_As
,
std
::
vector
<
const
void
*>&
p_Bs
,
std
::
vector
<
std
::
array
<
const
void
*
,
NumDTensor
>>&
p_Ds
,
std
::
vector
<
void
*>&
p_Es
,
std
::
vector
<
GemmDesc
>&
gemm_descs
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
,
index_t
kbatch
,
int
occupancy_num_blocks
,
int
gpu_cu_count
)
:
K_BATCH
{
kbatch
},
group_count_
{
0
},
skipped_group_count_
{
0
},
tile_count_
{
0
},
occupancy_num_blocks_
{
occupancy_num_blocks
},
gpu_cu_count_
{
gpu_cu_count
},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
}
{
group_count_
=
ck
::
type_convert
<
ck
::
index_t
>
(
gemm_descs
.
size
());
if
(
!
(
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_As
.
size
())
&&
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_Bs
.
size
())
&&
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_Es
.
size
())))
{
throw
std
::
runtime_error
(
"Error! group_count_ != p_As/Bs/Ds/Es size"
);
}
gemm_kernel_args_
.
reserve
(
group_count_
);
for
(
std
::
size_t
i
=
0
;
i
<
gemm_descs
.
size
();
++
i
)
{
const
index_t
M
=
gemm_descs
[
i
].
M_
;
const
index_t
N
=
gemm_descs
[
i
].
N_
;
const
index_t
K
=
gemm_descs
[
i
].
K_
;
if
(
M
*
N
*
K
==
0
)
{
skipped_group_count_
++
;
continue
;
}
const
index_t
stride_a
=
gemm_descs
[
i
].
stride_A_
;
const
index_t
stride_b
=
gemm_descs
[
i
].
stride_B_
;
const
index_t
stride_e
=
gemm_descs
[
i
].
stride_C_
;
auto
b2c_tile_map
=
Block2ETileMapKSplit
{
M
,
N
,
K_BATCH
};
const
index_t
grid_size_grp
=
b2c_tile_map
.
CalculateGridSize
(
M
,
N
);
tile_count_
+=
grid_size_grp
;
std
::
array
<
index_t
,
NumDTensor
>
stride_ds
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
j
)
{
if
(
gemm_descs
[
i
].
stride_Ds_
.
size
()
!=
NumDTensor
)
{
throw
std
::
runtime_error
(
"Error! gemm_descs[i].stride_Ds_.size() does not match NumDTensor"
);
}
stride_ds
[
j
]
=
gemm_descs
[
i
].
stride_Ds_
[
j
];
});
gemm_kernel_args_
.
emplace_back
(
type_convert
<
const
ADataType
*>
(
p_As
[
i
]),
type_convert
<
const
BDataType
*>
(
p_Bs
[
i
]),
p_Ds
[
i
],
type_convert
<
EDataType
*>
(
p_Es
[
i
]),
M
,
N
,
K
,
stride_a
,
stride_b
,
stride_ds
,
stride_e
);
}
}
/**
 * @brief Replace the split-K factor and recompute the overall tile count.
 *
 * Sums CalculateGridSize(M, N) of a tile map built with the new K_BATCH over all
 * stored kernel arguments and stores the sum in tile_count_.
 *
 * @param[in] kbatch The new splitK parameter value.
 */
void UpdateKBatch(index_t kbatch)
{
    K_BATCH = kbatch;

    index_t total_tiles = 0;
    for(const auto& group : gemm_kernel_args_)
    {
        const Block2ETileMapKSplit tile_map{group.M, group.N, K_BATCH};
        total_tiles += tile_map.CalculateGridSize(group.M, group.N);
    }
    tile_count_ = total_tiles;
}
// private:
index_t
K_BATCH
;
index_t
group_count_
;
index_t
skipped_group_count_
;
// The overall number of output tiles to be processed.
index_t
tile_count_
;
const
void
*
p_dev_gemm_args_
;
int
occupancy_num_blocks_
;
int
gpu_cu_count_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CDEElementwiseOperation
cde_element_op_
;
std
::
vector
<
KernelArguments
>
gemm_kernel_args_
;
};
/// Compile-time constants used to bound the number of workgroups per compute unit.
struct KernelConfig
{
    // The oversubscription factor for the number of blocks that can simultaneously
    // reside on GPU.
    static constexpr int BLOCK_SUBSCRIPTION_FACTOR = 1;
    // Number of waves one workgroup consists of.
    static constexpr int BLOCK_WAVES = BlockSize / get_warp_size();
    static constexpr int CU_SIMDS    = 4;
    // Assume we want to have at most 2 waves per SIMD
    static constexpr int CU_BLOCKS = math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES);
};
// Invoker
struct
Invoker
:
public
BaseInvoker
{
///
/// @brief      Launch Grouped Gemm kernel.
///
/// @note       This overload uses device buffers supplied directly by the caller.
///
/// @param[in]  arg                 The structure containing kernel arguments (in host
///                                 memory).
/// @param[in]  dev_gemm_args       The pointer to device memory with kernel arguments.
/// @param[in]  dev_gemm_workspace  The pointer to device memory for kernel auxiliary
///                                 workspace.
/// @param[in]  stream_config       The device stream configuration.
///
/// @return     The average kernel execution time (if time measurement is enabled.)
///
/// @throws     std::runtime_error when either device pointer is null, or when
///             CheckArgument rejects the arguments.
///
float Run(const Argument& arg,
          const void* dev_gemm_args,
          void* dev_gemm_workspace,
          const StreamConfig& stream_config = StreamConfig{})
{
    auto [kbatch_gt_one, has_main_k_block_loop] = CheckArgument(arg, stream_config);

    if(dev_gemm_args == nullptr)
    {
        std::ostringstream err;
        err << "The gemm arguments device buffer is not allocated!"
            << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
        throw std::runtime_error(err.str());
    }

    if(dev_gemm_workspace == nullptr)
    {
        std::ostringstream err;
        err << "The gemm workspace buffer is not allocated!"
            << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
        throw std::runtime_error(err.str());
    }

    // The main-K-loop flag is a compile-time kernel parameter, hence the two branches.
    return has_main_k_block_loop
               ? DispatchKernel<true>(arg, dev_gemm_args, dev_gemm_workspace, stream_config)
               : DispatchKernel<false>(arg, dev_gemm_args, dev_gemm_workspace, stream_config);
}
///
/// @brief      Launch Grouped Gemm kernel.
///
/// @note       This overload takes both device buffers (kernel arguments and auxiliary
///             workspace) from @p arg. The user should call @see GetDeviceKernelArgSize,
///             @see GetWorkSpaceSize and @see SetDeviceKernelArgs,
///             @see SetWorkSpacePointer on arg to properly allocate those buffers.
///
/// @param[in]  arg            The structure containing kernel arguments (in host memory).
/// @param[in]  stream_config  The device stream configuration.
///
/// @return     The average kernel execution time (if time measurement is enabled.)
///
/// @throws     std::runtime_error when either buffer pointer stored in arg is null.
///
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
    const void* const dev_args = arg.p_dev_gemm_args_;
    if(dev_args == nullptr)
    {
        std::ostringstream err;
        err << "The gemm arguments device buffer is not allocated!"
            << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
        throw std::runtime_error(err.str());
    }

    if(arg.p_workspace_ == nullptr)
    {
        std::ostringstream err;
        err << "The gemm workspace buffer is not allocated!"
            << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
        throw std::runtime_error(err.str());
    }

    return Run(arg, dev_args, arg.p_workspace_, stream_config);
}
/// @brief Type-erased entry point: downcast the base argument and forward to the
///        Argument overload of Run.
///
/// NOTE(review): the dynamic_cast result is dereferenced without a null check, so
/// p_arg is expected to point at an Argument of this device operation.
float Run(const BaseArgument* p_arg,
          const StreamConfig& stream_config = StreamConfig{}) override
{
    const auto* typed_arg = dynamic_cast<const Argument*>(p_arg);
    return Run(*typed_arg, stream_config);
}
private:
/// @brief Validate every group and derive the two flags shared by all groups.
///
/// @param[in] arg            The structure containing kernel arguments (in host memory).
/// @param[in] stream_config  Used for its log level only; each group is printed when
///                           log_level_ > 0.
///
/// @return Tuple (kbatch > 1, has main K block loop). Both values are taken from the
///         first group and every other group is required to agree with them.
///
/// @throws std::runtime_error when there is no group to run, when a group fails
///         GridwiseGemm::CheckValidity, or when groups disagree on either flag.
auto CheckArgument(const Argument& arg, const StreamConfig& stream_config) const
{
    // Without this guard the reference values below would be read from an empty vector
    // (this happens when every group has been skipped as empty).
    if(arg.gemm_kernel_args_.empty())
    {
        std::ostringstream err;
        err << "There are no non-empty gemm groups to run! In " << __FILE__ << ":"
            << __LINE__ << ", in function: " << __func__;
        throw std::runtime_error(err.str());
    }

    // Evaluates the main-K-loop flag for ONE group from that group's own A descriptor.
    const auto has_main_k_block_loop = [&arg](const auto& gemm_arg) -> bool {
        const auto a_grid_desc_kbatch_ak0_m_ak1 =
            GridwiseGemm::MakeAGridDescriptor_KBatch_AK0_M_AK1(
                gemm_arg.M, gemm_arg.K, gemm_arg.StrideA, arg.K_BATCH);

        return GridwiseGemm::CalculateHasMainKBlockLoop(
            a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
            a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3));
    };

    const bool all_have_kbatch_gt_one = arg.K_BATCH > 1;
    const bool all_have_main_k_block_loop =
        has_main_k_block_loop(arg.gemm_kernel_args_.front());

    for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
    {
        const auto& gemm_arg = arg.gemm_kernel_args_[i];
        if(stream_config.log_level_ > 0)
        {
            gemm_arg.Print();
        }

        // Currently all groups use same kbatch value.
        auto kbatch = arg.K_BATCH;

        if(!GridwiseGemm::CheckValidity(gemm_arg.M,
                                        gemm_arg.N,
                                        gemm_arg.K,
                                        gemm_arg.StrideA,
                                        gemm_arg.StrideB,
                                        gemm_arg.StrideDs,
                                        gemm_arg.StrideE,
                                        kbatch))
        {
            std::ostringstream err;
            err << "Group id: " << i << " has invalid GridwiseGemm settings! In "
                << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
            throw std::runtime_error(err.str());
        }

        // Compare THIS group against the reference value. Using group 0 here again
        // would turn the check into a comparison of group 0 with itself.
        const bool not_all_have_main_k_block_loop_same =
            all_have_main_k_block_loop xor has_main_k_block_loop(gemm_arg);
        const bool not_all_have_kbatch_value_same = all_have_kbatch_gt_one xor (kbatch > 1);

        if(not_all_have_main_k_block_loop_same)
        {
            std::ostringstream err;
            err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__
                << ":" << __LINE__ << ", in function: " << __func__;
            throw std::runtime_error(err.str());
        }

        if(not_all_have_kbatch_value_same)
        {
            std::ostringstream err;
            err << "Not all gemms have same kbatch value (=1 or >1)! "
                << "group [" << i << "], kbatch: " << kbatch
                << ", group [0], kbatch: " << arg.K_BATCH << " in " << __FILE__ << ":"
                << __LINE__ << ", in function: " << __func__;
            throw std::runtime_error(err.str());
        }
    }

    return std::make_tuple(all_have_kbatch_gt_one, all_have_main_k_block_loop);
}
/// @brief Select the kernel instantiation matching HasMainKBlockLoop and launch it.
///
/// @tparam HasMainKBlockLoop  Forwarded as the last template argument of the kernel.
///
/// @return The value returned by LaunchKernel.
template <bool HasMainKBlockLoop>
float DispatchKernel(const Argument& arg,
                     const void* dev_gemm_args,
                     void* dev_gemm_workspace,
                     const StreamConfig& stream_config) const
{
    const auto kernel_fn = kernel_grouped_gemm_xdl_splitk_v2<GridwiseGemm,
                                                             KernelArguments,
                                                             ADataType,
                                                             BDataType,
                                                             EDataType,
                                                             Block2ETileMapKSplit,
                                                             AElementwiseOperation,
                                                             BElementwiseOperation,
                                                             CDEElementwiseOperation,
                                                             HasMainKBlockLoop>;

    return LaunchKernel(kernel_fn, arg, dev_gemm_args, dev_gemm_workspace, stream_config);
}
/// @brief Compute the grid size used as the occupancy upper bound for the launch.
///
/// The result is the available CU count multiplied by the smaller of the occupancy
/// reported by HIP for @p kernel and KernelConfig::CU_BLOCKS.
///
/// @param[in] kernel         The kernel function queried for occupancy.
/// @param[in] stream_config  Used to query the available CU count and the log level.
///
/// @return The occupancy-limited grid size.
template <typename KernelFunction>
int CalculateMaxOccupancyGridSize(const KernelFunction& kernel,
                                  const StreamConfig& stream_config) const
{
    // Calculate max number of workgroups that can simultaneously reside on the CU.
    int max_blocks_per_cu               = 0;
    const size_t dyn_shared_mem_per_blk = 0;
    hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
        &max_blocks_per_cu, kernel, BlockSize, dyn_shared_mem_per_blk));

    const int available_cus = getAvailableComputeUnitCount(stream_config);
    const int occupancy_grid_size =
        available_cus * ck::math::min(max_blocks_per_cu, KernelConfig::CU_BLOCKS);

    if(stream_config.log_level_ > 0)
    {
        std::cout << "MaxActiveBlocksPerCU: " << max_blocks_per_cu
                  << ", available CUs count: " << available_cus
                  << ", occup. grid size: " << occupancy_grid_size << std::endl;
    }

    return occupancy_grid_size;
}
template
<
typename
KernelFunction
>
float
LaunchKernel
(
const
KernelFunction
&
kernel
,
const
Argument
&
arg
,
const
void
*
dev_gemm_args
,
void
*
dev_gemm_workspace
,
const
StreamConfig
&
stream_config
)
const
{
int
max_occupancy_grid_size
=
CalculateMaxOccupancyGridSize
(
kernel
,
stream_config
);
// We launch the smaller number of workgroups from acutally needed tiles and the
// number of workgroups that maximize the GPU occupancy. That is because for some tile
// configuration the first is smaller than the latter. Launching too many workgroups
// mean some of them will have to iterate through all gemm problem descriptors just to
// find out they have nothing to do which is of course waste of GPU cycles.
if
(
stream_config
.
log_level_
>
0
)
{
const
index_t
grid_size
=
ck
::
math
::
min
(
arg
.
tile_count_
,
max_occupancy_grid_size
);
const
index_t
tiles_per_block
=
(
arg
.
tile_count_
+
grid_size
-
1
)
/
grid_size
;
std
::
cout
<<
"tile_count: "
<<
arg
.
tile_count_
<<
", tiles_per_block: "
<<
tiles_per_block
<<
std
::
endl
;
}
return
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
ck
::
math
::
min
(
arg
.
tile_count_
,
max_occupancy_grid_size
)),
dim3
(
BlockSize
),
0
,
cast_pointer_to_constant_address_space
(
dev_gemm_args
),
dev_gemm_workspace
,
arg
.
tile_count_
,
arg
.
K_BATCH
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
cde_element_op_
);
}
};
/// @brief Report whether the template parameters form a valid configuration.
///
/// Currently a placeholder that accepts every configuration.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    return true;
}
/// @brief Check whether this device operation can run the given argument.
///
/// @param[in] arg  The structure containing kernel arguments (in host memory).
///
/// @return false when the stored and skipped group counts do not add up to the total
///         group count, or when any group fails GridwiseGemm::CheckValidity;
///         true otherwise. All groups are visited even after a failure.
static bool IsSupportedArgument(const Argument& arg)
{
    const auto accounted_groups =
        ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) + arg.skipped_group_count_;

    if(accounted_groups != arg.group_count_)
    {
#if DEBUG_LOG
        std::cout << "The group count is not equal to sum of skipped groups "
                     "and kernel args size!"
                  << std::endl;
#endif // DEBUG_LOG
        return false;
    }

    bool all_groups_valid = true;
    for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
    {
        const auto& gemm_arg = arg.gemm_kernel_args_[i];

        const bool group_arg_valid = GridwiseGemm::CheckValidity(gemm_arg.M,
                                                                 gemm_arg.N,
                                                                 gemm_arg.K,
                                                                 gemm_arg.StrideA,
                                                                 gemm_arg.StrideB,
                                                                 gemm_arg.StrideDs,
                                                                 gemm_arg.StrideE,
                                                                 arg.K_BATCH);
        if(!group_arg_valid)
        {
#if DEBUG_LOG
            std::cout << "[" << __func__ << "] group id: " << i
                      << " has invalid GridwiseGemm settings!" << std::endl;
            gemm_arg.Print();
#endif // DEBUG_LOG
            all_groups_valid = false;
        }
    }

    return all_groups_valid;
}
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
std
::
vector
<
const
void
*>&
p_As
,
std
::
vector
<
const
void
*>&
p_Bs
,
std
::
vector
<
std
::
array
<
const
void
*
,
NumDTensor
>>&
p_Ds
,
std
::
vector
<
void
*>&
p_Es
,
std
::
vector
<
GemmDesc
>
gemm_descs
,
AElementwiseOperation
a_elementwise_op
,
BElementwiseOperation
b_elementwise_op
,
CDEElementwiseOperation
cde_elementwise_op
)
{
const
auto
kernel
=
kernel_grouped_gemm_xdl_splitk_v2
<
GridwiseGemm
,
KernelArguments
,
ADataType
,
BDataType
,
EDataType
,
Block2ETileMapKSplit
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
true
>
;
int
occupancy
,
num_cu
;
hip_check_error
(
hipOccupancyMaxActiveBlocksPerMultiprocessor
(
&
occupancy
,
kernel
,
BlockSize
,
0
));
hipDeviceProp_t
dev_prop
;
hipDevice_t
dev
;
hip_check_error
(
hipGetDevice
(
&
dev
));
hip_check_error
(
hipGetDeviceProperties
(
&
dev_prop
,
dev
));
num_cu
=
dev_prop
.
multiProcessorCount
;
return
Argument
{
p_As
,
p_Bs
,
p_Ds
,
p_Es
,
gemm_descs
,
a_elementwise_op
,
b_elementwise_op
,
cde_elementwise_op
,
occupancy
,
num_cu
};
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
std
::
vector
<
const
void
*>&
p_As
,
std
::
vector
<
const
void
*>&
p_Bs
,
std
::
vector
<
std
::
array
<
const
void
*
,
NumDTensor
>>&
p_Ds
,
std
::
vector
<
void
*>&
p_Es
,
std
::
vector
<
GemmDesc
>&
gemm_descs
,
AElementwiseOperation
a_elementwise_op
,
BElementwiseOperation
b_elementwise_op
,
CDEElementwiseOperation
cde_elementwise_op
)
override
{
const
auto
kernel
=
kernel_grouped_gemm_xdl_splitk_v2
<
GridwiseGemm
,
KernelArguments
,
ADataType
,
BDataType
,
EDataType
,
Block2ETileMapKSplit
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
true
>
;
int
occupancy
,
num_cu
;
hip_check_error
(
hipOccupancyMaxActiveBlocksPerMultiprocessor
(
&
occupancy
,
kernel
,
BlockSize
,
0
));
hipDeviceProp_t
dev_prop
;
hipDevice_t
dev
;
hip_check_error
(
hipGetDevice
(
&
dev
));
hip_check_error
(
hipGetDeviceProperties
(
&
dev_prop
,
dev
));
num_cu
=
dev_prop
.
multiProcessorCount
;
return
std
::
make_unique
<
Argument
>
(
p_As
,
p_Bs
,
p_Ds
,
p_Es
,
gemm_descs
,
a_elementwise_op
,
b_elementwise_op
,
cde_elementwise_op
,
occupancy
,
num_cu
);
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
/// @brief Build a textual identifier of this instance.
///
/// The string holds the operation name followed by the first character of the A, B and E
/// layout names, the tile/vector configuration values and the GEMM specialization name.
std::string GetTypeString() const override
{
    std::stringstream str;

    // clang-format off
    str << "DeviceGroupedGemm_XdlSplitKTileLoop"
        << "<"
        << std::string(ALayout::name)[0] << ","
        << std::string(BLayout::name)[0] << ","
        << std::string(ELayout::name)[0] << ","
        << BlockSize << ", "
        << MPerBlock << ", "
        << NPerBlock << ", "
        << KPerBlock << ", "
        << AK1 << ", "
        << BK1 << ", "
        << MPerXDL << ", "
        << NPerXDL << ", "
        << MXdlPerWave << ", "
        << NXdlPerWave << ", "
        << ABlockTransferSrcScalarPerVector << ", "
        << BBlockTransferSrcScalarPerVector << ", "
        << CShuffleMXdlPerWavePerShuffle << ", "
        << CShuffleNXdlPerWavePerShuffle << ", "
        << getGemmSpecializationString(GemmSpec)
        << ">";
    // clang-format on

    return str.str();
}
/// @brief Store the device pointer to the kernel arguments buffer in @p arg.
///
/// @param[in,out] arg                The argument whose p_dev_gemm_args_ is set.
/// @param[in]     p_dev_kernel_args  The pointer to store.
static void SetDeviceKernelArgs(Argument& arg, const void* p_dev_kernel_args)
{
    arg.p_dev_gemm_args_ = p_dev_kernel_args;
}
/// @brief Type-erased variant: downcast the base argument and forward to the static
///        SetDeviceKernelArgs.
///
/// NOTE(review): the dynamic_cast result is dereferenced without a null check.
void SetDeviceKernelArgs(BaseArgument* p_arg, const void* p_dev_kernel_args) const override
{
    auto& typed_arg = *dynamic_cast<Argument*>(p_arg);
    SetDeviceKernelArgs(typed_arg, p_dev_kernel_args);
}
/// @brief Compute the size in bytes of the auxiliary device workspace.
///
/// The size is the accumulator workspace reported by the tile map for the launch grid
/// plus one uint32_t flag per group of K_BATCH tiles.
///
/// @param[in] p_arg  Pointer to an Argument of this device operation.
///
/// @return The workspace size in bytes; 0 when the grid size or K_BATCH is not positive.
///
/// NOTE(review): the dynamic_cast result is dereferenced without a null check.
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
    // Bind by reference: copying the Argument would also copy its kernel-args vector.
    const auto& arg = *dynamic_cast<const Argument*>(p_arg);

    const int occ_grid_size =
        arg.gpu_cu_count_ * std::min(arg.occupancy_num_blocks_, KernelConfig::CU_BLOCKS);
    const int grid_size = std::min(arg.tile_count_, occ_grid_size);

    // Both values are used as divisors below. With no tiles to process (or an invalid
    // kbatch) there is nothing to accumulate, so no workspace is needed.
    if(grid_size <= 0 || arg.K_BATCH <= 0)
    {
        return 0;
    }

    const int tiles_per_block = (arg.tile_count_ + grid_size - 1) / grid_size;
    const int flag_count = (grid_size * tiles_per_block + arg.K_BATCH - 1) / arg.K_BATCH;

    // This would be the maximum needed workspace size. Since actual grid size, which
    // determines the amount of workspace bytes needed, may be less due to the number of
    // available CUs in stream used to launch kernel.
    const size_t size_bytes =
        Block2ETileMapKSplit::GetAccWorkspaceSize(sizeof(AccDataType), grid_size) +
        flag_count * sizeof(uint32_t);

    return size_bytes;
}
/// @brief Store the workspace pointer in the argument's p_workspace_ member.
///
/// NOTE(review): the dynamic_cast result is dereferenced without a null check.
void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const override
{
    dynamic_cast<Argument*>(p_arg)->p_workspace_ = p_workspace;
}
/// @brief Set a new split-K factor on @p arg by calling Argument::UpdateKBatch.
///
/// @param[in,out] arg     The argument to update.
/// @param[in]     kbatch  The new splitK parameter value.
static void SetKBatchSize(Argument& arg, index_t kbatch)
{
    arg.UpdateKBatch(kbatch);
}
/// @brief Type-erased variant: downcast the base argument and forward to the static
///        SetKBatchSize.
///
/// NOTE(review): the dynamic_cast result is dereferenced without a null check.
void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override
{
    auto& typed_arg = *dynamic_cast<Argument*>(p_arg);
    SetKBatchSize(typed_arg, kbatch);
}
/// @brief Size in bytes of the device buffer holding the kernel arguments.
///
/// @return Number of stored kernel arguments multiplied by sizeof(KernelArguments).
///
/// NOTE(review): the dynamic_cast result is dereferenced without a null check.
size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
{
    const auto* typed_arg = dynamic_cast<const Argument*>(p_arg);
    return typed_arg->gemm_kernel_args_.size() * sizeof(KernelArguments);
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle_v2.hpp
0 → 100644
View file @
51ae4aa2
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
namespace
ck
{
// GEMM:
// input : A[M, K]
// input : B[N, K]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template
<
typename
ADataType
,
typename
BDataType
,
typename
ComputeType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
ALayout
,
typename
BLayout
,
typename
DsLayout
,
typename
ELayout
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
tensor_operation
::
device
::
GemmSpecialization
GemmSpec
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
AK1Value
,
index_t
BK1Value
,
index_t
MPerXdl
,
index_t
NPerXdl
,
index_t
MXdlPerWave
,
index_t
NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_AK1
,
bool
AThreadTransferSrcResetCoordinateAfterRun
,
index_t
ABlockLdsExtraM
,
typename
BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_BK1
,
bool
BThreadTransferSrcResetCoordinateAfterRun
,
index_t
BBlockLdsExtraN
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CDEShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
,
PipelineVersion
PipelineVer
>
class
GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
{
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
using
GemmSpecialization
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
I6
=
Number
<
6
>
{};
static
constexpr
auto
I7
=
Number
<
7
>
{};
static
constexpr
auto
AK1
=
Number
<
AK1Value
>
{};
static
constexpr
auto
BK1
=
Number
<
BK1Value
>
{};
static
constexpr
auto
AK0PerBlock
=
Number
<
KPerBlock
/
AK1Value
>
{};
static
constexpr
auto
BK0PerBlock
=
Number
<
KPerBlock
/
BK1Value
>
{};
static
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
ComputeType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
,
LoopSched
>
())
>
;
public:
using
AccType
=
AccDataType
;
/// @brief Round M up to a multiple of MPerBlock.
__host__ __device__ static auto CalculateMPadded(index_t M)
{
    return math::integer_least_multiple(M, MPerBlock);
}
/// @brief Round N up to a multiple of NPerBlock.
__host__ __device__ static auto CalculateNPadded(index_t N)
{
    return math::integer_least_multiple(N, NPerBlock);
}
/// @brief Round K up to a multiple of KPerBlock * K_Batch.
__host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch)
{
    return math::integer_least_multiple(K, KPerBlock * K_Batch);
}
/// @brief 4-D descriptor (1, AK0PerBlock, MPerBlock, AK1) of the A block tile.
///
/// The M stride is AK1, and the AK0 stride is (MPerBlock + ABlockLdsExtraM) * AK1,
/// i.e. each AK0 row is padded by ABlockLdsExtraM elements of length AK1.
__host__ __device__ static constexpr auto GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1()
{
    // A matrix in LDS memory, dst of blockwise copy
    constexpr auto lengths = make_tuple(I1, AK0PerBlock, Number<MPerBlock>{}, AK1);
    constexpr auto strides =
        make_tuple(AK0PerBlock * Number<MPerBlock + ABlockLdsExtraM>{} * AK1,
                   Number<MPerBlock + ABlockLdsExtraM>{} * AK1,
                   AK1,
                   I1);

    return make_naive_tensor_descriptor(lengths, strides);
}
/// @brief 4-D descriptor (1, BK0PerBlock, NPerBlock, BK1) of the B block tile.
///
/// The N stride is BK1, and the BK0 stride is (NPerBlock + BBlockLdsExtraN) * BK1,
/// i.e. each BK0 row is padded by BBlockLdsExtraN elements of length BK1.
__host__ __device__ static constexpr auto GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1()
{
    // B matrix in LDS memory, dst of blockwise copy
    constexpr auto lengths = make_tuple(I1, BK0PerBlock, Number<NPerBlock>{}, BK1);
    constexpr auto strides =
        make_tuple(BK0PerBlock * Number<NPerBlock + BBlockLdsExtraN>{} * BK1,
                   Number<NPerBlock + BBlockLdsExtraN>{} * BK1,
                   BK1,
                   I1);

    return make_naive_tensor_descriptor(lengths, strides);
}
/// @brief 3-D descriptor (AK0PerBlock, MPerBlock, AK1) of the A block tile, using the
///        strides ((MPerBlock + ABlockLdsExtraM) * AK1, AK1, 1).
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
    // A matrix in LDS memory, dst of blockwise copy
    constexpr auto lengths = make_tuple(AK0PerBlock, Number<MPerBlock>{}, AK1);
    constexpr auto strides =
        make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1);

    return make_naive_tensor_descriptor(lengths, strides);
}
/// @brief 3-D descriptor (BK0PerBlock, NPerBlock, BK1) of the B block tile, using the
///        strides ((NPerBlock + BBlockLdsExtraN) * BK1, BK1, 1).
__host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
    // B matrix in LDS memory, dst of blockwise copy
    constexpr auto lengths = make_tuple(BK0PerBlock, Number<NPerBlock>{}, BK1);
    constexpr auto strides =
        make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1);

    return make_naive_tensor_descriptor(lengths, strides);
}
/// @brief Build the (KBatch, AK0, M, AK1) view of the A matrix.
///
/// Steps: create the (M, K) descriptor according to ALayout, right-pad K up to
/// CalculateKPadded(K, KBatch), split the padded K into (KBatch, AK0, AK1), and
/// right-pad M to CalculateMPadded(M) only for the M-padding GEMM specializations.
///
/// @param[in] M        Number of rows of A.
/// @param[in] K        Reduction dimension length.
/// @param[in] StrideA  Stride of the non-contiguous dimension of A.
/// @param[in] KBatch   The splitK parameter value.
__host__ __device__ static auto
MakeAGridDescriptor_KBatch_AK0_M_AK1(index_t M, index_t K, index_t StrideA, index_t KBatch)
{
    const auto desc_m_k = [&]() {
        if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
        {
            return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
        }
        else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
        {
            return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
        }
    }();

    const auto MPad = CalculateMPadded(M);
    const auto KPad = CalculateKPadded(K, KBatch);

    // K is always padded, independently of GemmSpec.
    const auto desc_m_kpad = transform_tensor_descriptor(
        desc_m_k,
        make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

    const auto AK0 = KPad / (KBatch * AK1);

    constexpr bool pad_m =
        GemmSpec == GemmSpecialization::MPadding || GemmSpec == GemmSpecialization::MNPadding ||
        GemmSpec == GemmSpecialization::MKPadding || GemmSpec == GemmSpecialization::MNKPadding;

    // Source dimension 1 (K) becomes output dimensions (0, 1, 3); source dimension 0 (M)
    // becomes output dimension 2.
    if constexpr(pad_m)
    {
        return transform_tensor_descriptor(
            desc_m_kpad,
            make_tuple(make_unmerge_transform(make_tuple(KBatch, AK0, AK1)),
                       make_right_pad_transform(M, MPad - M)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
    }
    else
    {
        return transform_tensor_descriptor(
            desc_m_kpad,
            make_tuple(make_unmerge_transform(make_tuple(KBatch, AK0, AK1)),
                       make_pass_through_transform(M)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
    }
}
/// @brief Build the (KBatch, BK0, N, BK1) view of the B matrix.
///
/// Steps: create the (K, N) descriptor according to BLayout, right-pad K up to
/// CalculateKPadded(K, KBatch), split the padded K into (KBatch, BK0, BK1), and
/// right-pad N to CalculateNPadded(N) only for the N-padding GEMM specializations.
///
/// @param[in] K        Reduction dimension length.
/// @param[in] N        Number of columns of B.
/// @param[in] StrideB  Stride of the non-contiguous dimension of B.
/// @param[in] KBatch   The splitK parameter value.
__host__ __device__ static auto
MakeBGridDescriptor_KBatch_BK0_N_BK1(index_t K, index_t N, index_t StrideB, index_t KBatch)
{
    const auto desc_k_n = [&]() {
        if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
        {
            return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
        }
        else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
        {
            return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
        }
    }();

    const auto NPad = CalculateNPadded(N);
    const auto KPad = CalculateKPadded(K, KBatch);

    // K is always padded, independently of GemmSpec.
    const auto desc_kpad_n = transform_tensor_descriptor(
        desc_k_n,
        make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

    const auto BK0 = KPad / (KBatch * BK1);

    constexpr bool pad_n =
        GemmSpec == GemmSpecialization::NPadding || GemmSpec == GemmSpecialization::MNPadding ||
        GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding;

    // Source dimension 0 (K) becomes output dimensions (0, 1, 3); source dimension 1 (N)
    // becomes output dimension 2.
    if constexpr(pad_n)
    {
        // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
        return transform_tensor_descriptor(
            desc_kpad_n,
            make_tuple(make_unmerge_transform(make_tuple(KBatch, BK0, BK1)),
                       make_right_pad_transform(N, NPad - N)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
    }
    else
    {
        return transform_tensor_descriptor(
            desc_kpad_n,
            make_tuple(make_unmerge_transform(make_tuple(KBatch, BK0, BK1)),
                       make_pass_through_transform(N)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
    }
}
private:
using
AGridDesc_KBatch_AK0_M_AK1
=
remove_cvref_t
<
decltype
(
MakeAGridDescriptor_KBatch_AK0_M_AK1
(
1
,
1
,
1
,
1
))
>
;
using
BGridDesc_KBatch_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
MakeBGridDescriptor_KBatch_BK0_N_BK1
(
1
,
1
,
1
,
1
))
>
;
using
ABlockDesc_KBatch_AK0PerB_MPerB_AK1
=
remove_cvref_t
<
decltype
(
GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1
())
>
;
using
BBlockDesc_KBatch_BK0PerB_NPerB_BK1
=
remove_cvref_t
<
decltype
(
GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1
())
>
;
using
ABlockwiseCopy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
AElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
1
,
AK0PerBlock
,
MPerBlock
,
AK1
>
,
ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ADataType
,
ComputeType
,
AGridDesc_KBatch_AK0_M_AK1
,
ABlockDesc_KBatch_AK0PerB_MPerB_AK1
,
ABlockTransferSrcAccessOrder
,
Sequence
<
2
,
0
,
1
,
3
>
,
ABlockTransferSrcVectorDim
,
3
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_AK1
,
1
,
1
,
AThreadTransferSrcResetCoordinateAfterRun
,
true
,
NumGemmKPrefetchStage
>
;
using
BBlockwiseCopy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
BElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
1
,
BK0PerBlock
,
NPerBlock
,
BK1
>
,
BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
BDataType
,
ComputeType
,
BGridDesc_KBatch_BK0_N_BK1
,
BBlockDesc_KBatch_BK0PerB_NPerB_BK1
,
BBlockTransferSrcAccessOrder
,
Sequence
<
2
,
0
,
1
,
3
>
,
BBlockTransferSrcVectorDim
,
3
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
1
,
1
,
BThreadTransferSrcResetCoordinateAfterRun
,
true
,
NumGemmKPrefetchStage
>
;
using
BlockwiseGemmT
=
remove_cvref_t
<
decltype
(
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
BlockSize
,
ComputeType
,
ComputeType
,
AccDataType
,
decltype
(
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
()),
decltype
(
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
()),
MPerXdl
,
NPerXdl
,
MXdlPerWave
,
NXdlPerWave
,
KPack
,
LoopSched
>
())
>
;
BlockwiseGemmT
blockwise_gemm_
{};
public:
/// @brief Packed 4-D descriptor of the C-shuffle tile with lengths
///        (1, CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
///         1, CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl).
__host__ __device__ static constexpr auto
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{
    constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
    constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

    constexpr auto shuffle_m = Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{};
    constexpr auto shuffle_n = Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{};

    return make_naive_tensor_descriptor_packed(make_tuple(I1, shuffle_m, I1, shuffle_n));
}
/// @brief Create a tuple of null pointers, one `const DDataType*` per D tensor.
///
/// Yields ck::Tuple<const D0DataType*, const D1DataType*, ...>; the DsGridPointer alias
/// below takes its type from this function's return value.
static constexpr auto MakeDsGridPointer()
{
    const auto make_null_pointer = [](auto i) {
        using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;

        return static_cast<const DDataType*>(nullptr);
    };

    return generate_tuple(make_null_pointer, Number<NumDTensor>{});
}

using DsGridPointer = decltype(MakeDsGridPointer());
/// @brief Number of bytes of shared memory (LDS) required by the kernel.
///
/// Returns the larger of the aligned A + B block tiles (in ComputeType elements) and the
/// C-shuffle tile (in CShuffleDataType elements).
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
    // LDS allocation for A and B: be careful of alignment
    constexpr auto a_tile_desc = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
    constexpr auto b_tile_desc = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

    // lds max alignment
    constexpr auto lds_alignment = math::lcm(AK1, BK1);

    constexpr auto a_tile_elems =
        math::integer_least_multiple(a_tile_desc.GetElementSpaceSize(), lds_alignment);
    constexpr auto b_tile_elems =
        math::integer_least_multiple(b_tile_desc.GetElementSpaceSize(), lds_alignment);

    // LDS allocation for C shuffle in LDS
    constexpr auto c_shuffle_desc =
        GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
    constexpr auto c_tile_elems = c_shuffle_desc.GetElementSpaceSize();

    return math::max((a_tile_elems + b_tile_elems) * sizeof(ComputeType),
                     c_tile_elems * sizeof(CShuffleDataType));
}
// E descriptor used as the destination of the blockwise copy.
//
// Splits an (M, N) descriptor into (MBlock, MPerBlock, NBlock, NPerBlock), where
// MBlock = M / MPerBlock and NBlock = N / NPerBlock (integer division).
template <typename EGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n)
{
    // Tile counts along each dimension.
    const auto m_tile_count = e_grid_desc_m_n.GetLength(I0) / MPerBlock;
    const auto n_tile_count = e_grid_desc_m_n.GetLength(I1) / NPerBlock;

    // dim 0 (M) -> dims {0, 1}; dim 1 (N) -> dims {2, 3}
    const auto split_m =
        make_unmerge_transform(make_tuple(m_tile_count, Number<MPerBlock>{}));
    const auto split_n =
        make_unmerge_transform(make_tuple(n_tile_count, Number<NPerBlock>{}));

    return transform_tensor_descriptor(e_grid_desc_m_n,
                                       make_tuple(split_m, split_n),
                                       make_tuple(Sequence<0>{}, Sequence<1>{}),
                                       make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
}
// Ds descriptors used as sources of the blockwise copy.
//
// Applies MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() to each of the
// NumDTensor entries of ds_grid_desc_m_n and returns the results as a tuple.
template <typename DsGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n)
{
    const auto to_blocked_desc = [&](auto d_idx) {
        return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[d_idx]);
    };

    return generate_tuple(to_blocked_desc, Number<NumDTensor>{});
}
// Returns the default mapping from block id to E-matrix tile index (m0, n0):
// a BlockToCTileMap_M00_N0_M01Adapt built from the given (M, N) descriptor.
template <typename EGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
{
    using DefaultTileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>;

    return DefaultTileMap(e_grid_desc_m_n);
}
// Checks whether this gridwise GEMM can handle the given problem.
//
// Returns false when
//   * the GEMM pipeline does not support the resulting number of K loops
//     (AK0 * AK1 of the A descriptor divided by KPerBlock), or
//   * the A, B or E tensor occupies more than 2GB.
// StrideDs is accepted for interface symmetry but is not inspected.
__host__ __device__ static constexpr bool
CheckValidity(const index_t M,
              const index_t N,
              const index_t K,
              const index_t StrideA,
              const index_t StrideB,
              const std::array<index_t, NumDTensor> StrideDs,
              const index_t StrideE,
              const index_t KBatch)
{
    ignore = StrideDs;

    const auto a_desc = MakeAGridDescriptor_KBatch_AK0_M_AK1(M, K, StrideA, KBatch);
    const auto b_desc = MakeBGridDescriptor_KBatch_BK0_N_BK1(K, N, StrideB, KBatch);
    const auto e_desc = MakeEGridDescriptor_M_N<ELayout>(M, N, StrideE);

    // check gridwise gemm pipeline
    const auto k_loop_count = (a_desc.GetLength(I1) * a_desc.GetLength(I3)) / KPerBlock;

    if(!GridwiseGemmPipe::IsSupported(k_loop_count))
    {
        return false;
    }

    // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)

    // check tensor size: cannot be larger than 2GB each
    constexpr long_index_t TwoGB = (long_index_t{1} << 31);

    const bool a_fits = a_desc.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB;
    const bool b_fits = b_desc.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB;
    const bool e_fits = e_desc.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB;

    return a_fits && b_fits && e_fits;
}
// Asks the GEMM pipeline whether a main K loop is needed for K / KPerBlock
// iterations (integer division).
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
    const index_t k_iterations = K / KPerBlock;

    return GridwiseGemmPipe::CalculateHasMainLoop(k_iterations);
}
// Builds the (M, N) descriptor of an output-shaped tensor and pads it.
//
// TensorDataLayout selects the strides of the raw (MRaw, NRaw) descriptor:
//   RowMajor    -> (StrideE, 1)
//   ColumnMajor -> (1, StrideE)
// The raw descriptor is then passed through MatrixPadder<GemmSpec, ...>, which is
// configured with (MPerBlock, NPerBlock, KPerBlock).
template <typename TensorDataLayout>
__host__ __device__ static auto
MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
{
    using RowMajorLayout = tensor_layout::gemm::RowMajor;
    using ColMajorLayout = tensor_layout::gemm::ColumnMajor;
    using Padder =
        ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>;

    constexpr auto padder = Padder{MPerBlock, NPerBlock, KPerBlock};

    // Unpadded descriptor; the stride order depends on the layout.
    const auto unpadded_desc = [&]() {
        if constexpr(is_same<RowMajorLayout, TensorDataLayout>::value)
        {
            return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
                                                make_tuple(StrideE, I1));
        }
        else if constexpr(is_same<ColMajorLayout, TensorDataLayout>::value)
        {
            return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
                                                make_tuple(I1, StrideE));
        }
    }();

    return padder.PadCDescriptor_M_N(unpadded_desc);
}
// Builds one padded (M, N) descriptor per D tensor.
//
// For each index i in [0, NumDTensor) the i-th layout of DsLayout is combined with
// MRaws[i], NRaws[i] and DsStride[i] through MakeEGridDescriptor_M_N(); the results
// are returned as a tuple.
__host__ __device__ static auto
MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
                         const std::array<index_t, NumDTensor>& NRaws,
                         const std::array<index_t, NumDTensor>& DsStride)
{
    const auto make_d_desc = [&](auto d_idx) {
        using DLayout = remove_cvref_t<tuple_element_t<d_idx.value, DsLayout>>;

        return MakeEGridDescriptor_M_N<DLayout>(MRaws[d_idx], NRaws[d_idx], DsStride[d_idx]);
    };

    return generate_tuple(make_d_desc, Number<NumDTensor>{});
}
// TODO: we should refactor out all those common Make... descriptors to something like
// gridwise_gemm_utils.hpp
// Accessor for the MPerBlock tile-size parameter.
__device__ __host__ static constexpr auto GetMPerBlock()
{
    return MPerBlock;
}
// Gives access to the C thread buffer of the blockwise GEMM member; this is the
// same buffer that RunGEMM() passes to the GEMM pipeline.
__device__ __host__ constexpr auto& GetCThreadBuffer()
{
    return blockwise_gemm_.GetCThreadBuffer();
}
// Runs the GEMM pipeline for the tile selected by block_2_etile_map.
//
// A and B are read from global memory through blockwise copies into LDS
// (p_shared) and multiplied by blockwise_gemm_. The pipeline receives the C thread
// buffer of blockwise_gemm_ (see GetCThreadBuffer()); this function itself does not
// write any global memory.
//
// Parameters:
//   p_a_grid, p_b_grid            global A / B data
//   p_shared                      LDS scratch; the A tile starts at offset 0 and the
//                                 B tile follows the (aligned) A tile
//   KBatch                        unused here
//   a_element_op, b_element_op    element-wise operations given to the A / B copies
//   a_grid_desc_kbatch_ak0_m_ak1  A descriptor (KBatch, AK0, M, AK1)
//   b_grid_desc_kbatch_bk0_n_bk1  B descriptor (KBatch, BK0, N, BK1)
//   block_2_etile_map             GetBottomIndex() yields the tile index; entries
//                                 I0 / I1 / I2 are used as M tile, N tile and k-batch
template <bool HasMainKBlockLoop, typename Block2ETileMap>
__device__ void RunGEMM(const ADataType* __restrict__ p_a_grid,
                        const BDataType* __restrict__ p_b_grid,
                        void* __restrict__ p_shared,
                        [[maybe_unused]] const index_t KBatch,
                        const AElementwiseOperation& a_element_op,
                        const BElementwiseOperation& b_element_op,
                        const AGridDesc_KBatch_AK0_M_AK1& a_grid_desc_kbatch_ak0_m_ak1,
                        const BGridDesc_KBatch_BK0_N_BK1& b_grid_desc_kbatch_bk0_n_bk1,
                        const Block2ETileMap& block_2_etile_map)
{
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    // Global-memory views of A and B.
    const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_a_grid, a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize());
    const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_b_grid, b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize());

    // divide block work by [M, N, K]
    const auto tile_idx = block_2_etile_map.GetBottomIndex();

    // readfirstlane broadcasts the first active lane's value to the whole wave.
    const index_t k_batch_idx = __builtin_amdgcn_readfirstlane(tile_idx[I2]);
    const index_t m_origin    = __builtin_amdgcn_readfirstlane(tile_idx[I0] * MPerBlock);
    const index_t n_origin    = __builtin_amdgcn_readfirstlane(tile_idx[I1] * NPerBlock);

    // lds max alignment
    constexpr auto lds_alignment = math::lcm(AK1, BK1);

    // LDS-side descriptors (with the leading KBatch dimension) used as copy destinations.
    constexpr auto a_lds_desc_kbatch = GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1();
    constexpr auto b_lds_desc_kbatch = GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1();

    // A matrix blockwise copy: global (k_batch_idx, 0, m_origin, 0) -> LDS origin
    auto a_copy = ABlockwiseCopy(a_grid_desc_kbatch_ak0_m_ak1,
                                 make_multi_index(k_batch_idx, 0, m_origin, 0),
                                 a_element_op,
                                 a_lds_desc_kbatch,
                                 make_multi_index(0, 0, 0, 0),
                                 PassThrough{});

    // B matrix blockwise copy: global (k_batch_idx, 0, n_origin, 0) -> LDS origin
    auto b_copy = BBlockwiseCopy(b_grid_desc_kbatch_bk0_n_bk1,
                                 make_multi_index(k_batch_idx, 0, n_origin, 0),
                                 b_element_op,
                                 b_lds_desc_kbatch,
                                 make_multi_index(0, 0, 0, 0),
                                 PassThrough{});

    // LDS-side descriptors without the KBatch dimension; used for buffer sizing.
    constexpr auto a_lds_desc = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
    constexpr auto b_lds_desc = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

    // GEMM definition
    //   c_mtx += transpose(a_mtx) * b_mtx
    //     a_mtx[K0PerBlock, MPerBlock] is in LDS
    //     b_mtx[K0PerBlock, NPerBlock] is in LDS
    //     c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
    //     register
    auto& c_accumulator = blockwise_gemm_.GetCThreadBuffer();

    // LDS allocation for A and B: be careful of alignment.
    // The B tile starts right after the aligned A tile.
    constexpr auto b_lds_offset =
        math::integer_least_multiple(a_lds_desc.GetElementSpaceSize(), lds_alignment);

    ComputeType* const lds_base = static_cast<ComputeType*>(p_shared);

    auto a_lds_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
        lds_base, a_lds_desc.GetElementSpaceSize());
    auto b_lds_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
        lds_base + b_lds_offset, b_lds_desc.GetElementSpaceSize());

    // Per-iteration window step: advance dimension 1 (AK0 / BK0) by one K block.
    constexpr auto a_window_step = make_multi_index(0, KPerBlock / AK1, 0, 0);
    constexpr auto b_window_step = make_multi_index(0, KPerBlock / BK1, 0, 0);

    // gridwise GEMM pipeline
    const auto pipeline =
        GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();

    // Number of K blocks: (AK0 * AK1) / KPerBlock.
    const index_t k_loop_count = __builtin_amdgcn_readfirstlane(
        (a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
         a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)) /
        KPerBlock);

    pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_kbatch_ak0_m_ak1,
                                             a_lds_desc_kbatch,
                                             a_copy,
                                             a_global_buf,
                                             a_lds_buf,
                                             a_window_step,
                                             b_grid_desc_kbatch_bk0_n_bk1,
                                             b_lds_desc_kbatch,
                                             b_copy,
                                             b_global_buf,
                                             b_lds_buf,
                                             b_window_step,
                                             blockwise_gemm_,
                                             c_accumulator,
                                             k_loop_count);
}
// Convenience overload of RunGEMM() taking untyped pointers and raw problem sizes.
//
// Casts p_a_grid_ / p_b_grid_ to ADataType / BDataType pointers, builds the A and B
// grid descriptors from (M, N, K, StrideA, StrideB, KBatch) and forwards everything
// to the descriptor-based RunGEMM() overload.
template <bool HasMainKBlockLoop, typename Block2ETileMap>
__device__ void RunGEMM(const void* __restrict__ p_a_grid_,
                        const void* __restrict__ p_b_grid_,
                        void* __restrict__ p_shared,
                        const AElementwiseOperation& a_element_op,
                        const BElementwiseOperation& b_element_op,
                        const index_t M,
                        const index_t N,
                        const index_t K,
                        const index_t StrideA,
                        const index_t StrideB,
                        const index_t KBatch,
                        const Block2ETileMap& block_2_etile_map)
{
    // Typed views of the input pointers.
    const auto a_typed_ptr = static_cast<const ADataType*>(p_a_grid_);
    const auto b_typed_ptr = static_cast<const BDataType*>(p_b_grid_);

    // tensor descriptors for block/thread-wise copy
    const auto a_desc = MakeAGridDescriptor_KBatch_AK0_M_AK1(M, K, StrideA, KBatch);
    const auto b_desc = MakeBGridDescriptor_KBatch_BK0_N_BK1(K, N, StrideB, KBatch);

    RunGEMM<HasMainKBlockLoop>(a_typed_ptr,
                               b_typed_ptr,
                               p_shared,
                               KBatch,
                               a_element_op,
                               b_element_op,
                               a_desc,
                               b_desc,
                               block_2_etile_map);
}
// template <typename CThreadBufer,
// InMemoryDataOperationEnum EGlobalMemoryDataOperation,
// index_t NumDTensor_,
// typename DsDataType_,
// typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
// typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
// typename CDEElementwiseOperation_,
// typename Block2ETileMap>
// __device__ void RunWrite(CThreadBufer c_thread_buf,
// const EDataType* __restrict__ p_workspace,
// DsGridPointer p_ds_grid,
// EDataType* __restrict__ p_e_grid,
// void* __restrict__ p_shared,
// const index_t KBatch,
// const CDEElementwiseOperation_& cde_element_op,
// const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
// ds_grid_desc_mblock_mperblock_nblock_nperblock,
// const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
// e_grid_desc_mblock_mperblock_nblock_nperblock,
// const Block2ETileMap& block_2_etile_map)
// {
// using DsGridDesc_M_N =
// remove_cvref_t<decltype(MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>({}, {}, {}))>;
// DsGridDesc_M_N ds_grid_desc_m_n;
// const auto ds_grid_buf = generate_tuple(
// [&](auto i) {
// return make_dynamic_buffer<AddressSpaceEnum::Global>(
// p_ds_grid[i],
// ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
// },
// Number<NumDTensor_>{});
// auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
// p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// static_for<0, NumDTensor, 1>{}([&](auto j) {
// using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
// ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N<DLayout>(M, N, StrideDs[j]);
// });
// const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout>(M, N, StrideE);
// // using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
// // remove_cvref_t<decltype(MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
// // DsGridDesc_M_N{}))>;
// // DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
// ds_grid_desc_mblock_mperblock_nblock_nperblock;
// // static_for<0, NumDTensor, 1>{}([&](auto j) {
// // ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
// // MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]);
// // });
// // const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
// // MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
// // shuffle C and write out
// static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
// NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
// "wrong!");
// constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
// constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
// // TODO: hacky, fix it!
// constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
// blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// // TODO: hacky, fix it!
// // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
// constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
// blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
// constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
// constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
// constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
// constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
// constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
// constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
// constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
// constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
// GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
// auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
// static_cast<CShuffleDataType*>(p_shared),
// c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
// c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
// make_tuple(
// make_freeze_transform(I0),
// make_unmerge_transform(make_tuple(
// Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
// M1, // M1 = MWave
// M2, // M2 * M3 * M4 = MPerXdl
// M3,
// M4)),
// make_freeze_transform(I0),
// make_unmerge_transform(make_tuple(
// Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
// N1, // N1 = NWave
// N2))), // N2 = NPerXdl
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
// make_tuple(
// Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
// // calculate origin of thread output tensor on global memory
// // blockwise GEMM c matrix starting index
// const auto c_thread_mtx_on_block =
// blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
// const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
// const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
// const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
// make_single_stage_tensor_adaptor(
// make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
// make_tuple(Sequence<0, 1, 2, 3, 4>{}),
// make_tuple(Sequence<0>{}));
// const auto m_thread_data_on_block_idx =
// m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
// make_multi_index(m_thread_data_on_block));
// const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
// make_single_stage_tensor_adaptor(
// make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
// make_tuple(Sequence<0, 1, 2>{}),
// make_tuple(Sequence<0>{}));
// const auto n_thread_data_on_block_idx =
// n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
// make_multi_index(n_thread_data_on_block));
// // shuffle: threadwise copy C from VGPR to LDS
// auto c_thread_copy_vgpr_to_lds =
// ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
// CShuffleDataType,
// decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
// decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
// ck::tensor_operation::element_wise::PassThrough,
// Sequence<CShuffleMXdlPerWavePerShuffle,
// CShuffleNXdlPerWavePerShuffle,
// I1,
// I1,
// M2,
// I1,
// M4,
// I1>,
// Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
// 7,
// 1,
// InMemoryDataOperationEnum::Set,
// 1,
// true>{
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
// make_multi_index(0,
// 0,
// m_thread_data_on_block_idx[I1],
// n_thread_data_on_block_idx[I1],
// m_thread_data_on_block_idx[I2],
// m_thread_data_on_block_idx[I3],
// m_thread_data_on_block_idx[I4],
// n_thread_data_on_block_idx[I2]),
// ck::tensor_operation::element_wise::PassThrough{}};
// // tuple of reference to C/Ds tensor descriptors
// const auto c_ds_desc_refs = concat_tuple_of_reference(
// tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
// generate_tie(
// [&](auto i) -> const auto& // return type should be reference
// { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
// Number<NumDTensor_>{}));
// // tuple of reference to C/Ds tensor buffers
// const auto c_ds_buf_refs = concat_tuple_of_reference(
// tie(c_shuffle_block_buf),
// generate_tie(
// [&](auto i) -> const auto& // return type should be reference
// { return ds_grid_buf[i]; },
// Number<NumDTensor_>{}));
// // tuple of starting index of C/Ds blockwise copy
// const auto idx_c_ds_block_begin = container_concat(
// make_tuple(make_multi_index(0, 0, 0, 0)),
// generate_tuple(
// [&](auto) {
// return make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0);
// },
// Number<NumDTensor_>{}));
// // space filling curve for threadwise C in VGPR before shuffle
// constexpr auto sfc_c_vgpr =
// SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
// Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
// Sequence<CShuffleMXdlPerWavePerShuffle,
// CShuffleNXdlPerWavePerShuffle,
// 1,
// 1,
// M2,
// 1,
// M4,
// 1>>{};
// // space filling curve for shuffled blockwise C/D/E
// constexpr auto sfc_cde_block =
// SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
// Sequence<0, 2, 1, 3>,
// Sequence<1,
// CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
// 1,
// CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
// constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
// static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
// // blockwise copy C/D/E between LDS and global
// auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
// ThisThreadBlock,
// decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType_{})),
// Tuple<EDataType>,
// decltype(c_ds_desc_refs),
// decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
// CDEElementwiseOperation_,
// Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
// // Sequence support
// // arbitrary type
// Sequence<1,
// CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
// 1,
// CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
// CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
// Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
// Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
// 3, // index_t VectorDim,
// CDEShuffleBlockTransferScalarPerVector_NPerBlock,
// sequence_merge_t<
// Sequence<true>,
// uniform_sequence_gen_t<NumDTensor_,
// false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
// Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
// {c_ds_desc_refs,
// idx_c_ds_block_begin,
// tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
// make_tuple(make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0)),
// cde_element_op};
// static_for<0, num_access, 1>{}([&](auto access_id) {
// // make sure it's safe to write to LDS
// block_sync_lds();
// // each thread write its data from VGPR to LDS
// c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
// sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
// c_thread_buf,
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
// c_shuffle_block_buf);
// // make sure it's safe to read from LDS
// block_sync_lds();
// // each block copy its data from LDS to global
// cde_block_copy_lds_and_global.Run(
// c_ds_desc_refs,
// c_ds_buf_refs,
// tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
// tie(e_grid_buf));
// if constexpr(access_id < num_access - 1)
// {
// constexpr auto cde_lds_and_global_step =
// sfc_cde_block.GetForwardStep(access_id);
// // move on Ds
// static_for<0, NumDTensor_, 1>{}([&](auto i) {
// cde_block_copy_lds_and_global.MoveSrcSliceWindow(
// c_ds_desc_refs, i + I1, cde_lds_and_global_step);
// });
// // move on E
// cde_block_copy_lds_and_global.MoveDstSliceWindow(
// tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
// I0,
// cde_lds_and_global_step);
// }
// });
// }
};
}
// namespace ck
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment