gaoqiong / composable_kernel

Commit fa649421, authored Jul 16, 2023 by Jing Zhang
Parent: e845ad4c

    finished api

Showing 7 changed files with 231 additions and 162 deletions (+231 −162):
- example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp (+64 −38)
- include/ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp (+14 −0)
- include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp (+127 −121)
- include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp (+13 −2)
- include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp (+1 −1)
- library/include/ck/library/utility/device_memory.hpp (+2 −0)
- library/src/utility/device_memory.cpp (+10 −0)
example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp
```diff
@@ -79,27 +79,15 @@ struct ExecutionConfig final
 bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
 {
-    int group_count = problem_size.group_count;
+    auto group_count = problem_size.group_count;

     // GEMM shape
     std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
-    std::vector<const void*> p_a, p_b;
-    std::vector<void*> p_c;
+    std::vector<void*> p_Cs;

     gemm_descs.reserve(group_count);

-    for(int i = 0; i < group_count; i++)
-    {
-        int M        = problem_size.Ms[i];
-        int N        = problem_size.Ns[i];
-        int K        = problem_size.Ks[i];
-        int stride_A = problem_size.stride_As[i];
-        int stride_B = problem_size.stride_Bs[i];
-        int stride_C = problem_size.stride_Cs[i];
-        gemm_descs.push_back({M, N, K, stride_A, stride_B, stride_C, {}});
-    }
+    int sum_of_m = 0;

     auto f_host_tensor_descriptor =
         [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
```
```diff
@@ -135,21 +123,22 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
     std::size_t flop = 0, num_btype = 0;
-    for(std::size_t i = 0; i < gemm_descs.size(); i++)
+    for(int i = 0; i < group_count; i++)
     {
+        sum_of_m += problem_size.Ms[i];
         a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
-            gemm_descs[i].M_, gemm_descs[i].K_, gemm_descs[i].stride_A_, ALayout{})));
+            problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{})));
         b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
-            gemm_descs[i].K_, gemm_descs[i].N_, gemm_descs[i].stride_B_, BLayout{})));
+            problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{})));
         c_host_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
-            gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, ELayout{})));
+            problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
         c_device_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
-            gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, ELayout{})));
+            problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));

         std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
                   << " b_k_n: " << b_tensors[i].mDesc
                   << " c_m_n: " << c_device_tensors[i].mDesc << std::endl;

-        flop += std::size_t(2) * gemm_descs[i].M_ * gemm_descs[i].K_ * gemm_descs[i].N_;
+        flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i];

         num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() +
                      sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() +
                      sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSize();
```
```diff
@@ -171,22 +160,47 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
         }
     }

-    for(std::size_t i = 0; i < gemm_descs.size(); i++)
+    using GemmKernelArgument = ck::tensor_operation::device::GemmKernelArgument;
+    std::vector<GemmKernelArgument> simple_gemm_kernel_args_;
+    simple_gemm_kernel_args_.reserve(group_count);
+
+    for(int i = 0; i < group_count; i++)
     {
         a_tensors_device.emplace_back(std::make_unique<DeviceMem>(
-            sizeof(ADataType) * a_tensors[i].mDesc.GetElementSpaceSize()));
+            sizeof(ADataType) * sum_of_m * problem_size.Ks[i]));
         b_tensors_device.emplace_back(std::make_unique<DeviceMem>(
-            sizeof(BDataType) * b_tensors[i].mDesc.GetElementSpaceSize()));
+            sizeof(BDataType) * problem_size.Ns[i] * problem_size.Ks[i]));
         c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
-            sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSpaceSize()));
+            sizeof(EDataType) * sum_of_m * problem_size.Ns[i]));

-        a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
-        b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
+        a_tensors_device[i]->ToDevice(a_tensors[i].mData.data(),
+                                      a_tensors[i].mDesc.GetElementSpaceSize() * sizeof(ADataType));
+        b_tensors_device[i]->ToDevice(b_tensors[i].mData.data(),
+                                      b_tensors[i].mDesc.GetElementSpaceSize() * sizeof(BDataType));
         c_tensors_device[i]->SetZero();

-        p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
-        p_b.push_back(b_tensors_device[i]->GetDeviceBuffer());
-        p_c.push_back(c_tensors_device[i]->GetDeviceBuffer());
+        p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer());
+
+        gemm_descs.push_back({sum_of_m,
+                              problem_size.Ns[i],
+                              problem_size.Ks[i],
+                              problem_size.stride_As[i],
+                              problem_size.stride_Bs[i],
+                              problem_size.stride_Cs[i],
+                              {}});
+
+        simple_gemm_kernel_args_.push_back({a_tensors_device[i]->GetDeviceBuffer(),
+                                            b_tensors_device[i]->GetDeviceBuffer(),
+                                            c_tensors_device[i]->GetDeviceBuffer(),
+                                            problem_size.Ms[i],
+                                            problem_size.Ns[i],
+                                            problem_size.Ks[i],
+                                            problem_size.stride_As[i],
+                                            problem_size.stride_Bs[i],
+                                            problem_size.stride_Cs[i]});
     }

     auto a_element_op = AElementOp{};
```
```diff
@@ -196,17 +210,24 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
     auto gemm    = DeviceGemmInstance{};
     auto invoker = gemm.MakeInvoker();

+    std::vector<const void*> p_As                = {};
+    std::vector<const void*> p_Bs                = {};
     std::vector<std::array<const void*, 0>> p_Ds = {};

     // do GEMM
     auto argument = gemm.MakeArgument(
-        p_a, p_b, p_Ds, p_c, gemm_descs, a_element_op, b_element_op, c_element_op);
+        p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, c_element_op);

     DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(&argument));
     gemm.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer());

-    gemm.SetKBatchSize(argument, 8);
+    hip_check_error(hipMemcpy(gemm_desc_workspace.GetDeviceBuffer(),
+                              simple_gemm_kernel_args_.data(),
+                              gemm.GetWorkSpaceSize(&argument),
+                              hipMemcpyHostToDevice));
+
+    gemm.SetKBatchSize(argument, 4);

     if(!gemm.IsSupportedArgument(argument))
     {
```
```diff
@@ -215,7 +236,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
                                  "not support this GEMM problem");
     }

-    invoker.Run(argument, StreamConfig{nullptr, false});
+    invoker.Run(argument, gemm_desc_workspace.GetDeviceBuffer(), StreamConfig{nullptr, false});

     bool pass = true;
     if(config.do_verification)
```
```diff
@@ -230,7 +251,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
         for(std::size_t i = 0; i < gemm_descs.size(); i++)
         {
-            c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data());
+            c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(),
+                                            c_device_tensors[i].mDesc.GetElementSize() *
+                                                sizeof(EDataType));
             auto ref_gemm    = ReferenceGemmInstance{};
             auto ref_invoker = ref_gemm.MakeInvoker();
```
```diff
@@ -249,7 +272,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
     if(config.time_kernel)
     {
-        float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
+        float ave_time = invoker.Run(argument,
+                                     gemm_desc_workspace.GetDeviceBuffer(),
+                                     StreamConfig{nullptr, config.time_kernel});

         float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
         float gb_per_sec = num_btype / 1.E6 / ave_time;
```
```diff
@@ -267,7 +292,8 @@ int main(int argc, char* argv[])
     problem_size.group_count = 16;

-    problem_size.Ms = {0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0};
+    problem_size.Ms = {167, 183, 177, 181, 153, 139, 156, 173,
+                       163, 150, 204, 184, 168, 156, 168, 148};

     for(int i = 0; i < problem_size.group_count; i++)
     {
```
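The net effect of the example changes above: the host no longer hands per-group A/B pointers through MakeArgument; it packs one GemmKernelArgument per group and copies the whole array into the device workspace before Run. Below is a standalone sketch of that packing pattern (mock types, no HIP or CK required, all sizes hypothetical):

```cpp
// Standalone sketch of the new argument-passing pattern: per-group kernel
// arguments are plain structs that the host packs into one contiguous
// workspace buffer; the kernel side indexes that buffer by group id.
#include <cstring>
#include <iostream>
#include <vector>

struct GemmKernelArgument // mirrors the POD added in device_grouped_gemm_splitk.hpp
{
    const void* p_a_grid;
    const void* p_b_grid;
    void* p_c_grid;
    int M, N, K; // ck::index_t in the real header
    int StrideA, StrideB, StrideC;
};

int main()
{
    std::vector<GemmKernelArgument> args;
    for(int i = 0; i < 4; ++i) // hypothetical 4-group problem
        args.push_back({nullptr, nullptr, nullptr, 167 + i, 768, 512, 512, 768, 768});

    // Stand-in for the DeviceMem workspace + hipMemcpy: a flat byte buffer.
    std::vector<char> workspace(args.size() * sizeof(GemmKernelArgument));
    std::memcpy(workspace.data(), args.data(), workspace.size());

    // The kernel recovers its argument by group id.
    const auto* dev_args = reinterpret_cast<const GemmKernelArgument*>(workspace.data());
    std::cout << "group 2 problem: " << dev_args[2].M << " x " << dev_args[2].N
              << " x " << dev_args[2].K << '\n';
}
```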
include/ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp
```diff
@@ -8,6 +8,20 @@ namespace ck {
 namespace tensor_operation {
 namespace device {

+struct GemmKernelArgument
+{
+    const void* p_a_grid;
+    const void* p_b_grid;
+    void* p_c_grid;
+
+    index_t M;
+    index_t N;
+    index_t K;
+    index_t StrideA;
+    index_t StrideB;
+    index_t StrideC;
+};
+
 template <typename ALayout,
           typename BLayout,
           typename DsLayout,
```
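GemmKernelArgument is deliberately a bare aggregate: the host transfers an array of them to the GPU with a raw hipMemcpy, so the type must stay trivially copyable. A small compile-time check one could pair with it (standalone sketch; the int fields stand in for ck::index_t, which is an assumption):

```cpp
// Sanity check for the memcpy'd argument struct: if anyone adds a virtual
// function or a non-trivial member, the raw device copy would break, and
// this assertion catches it at compile time.
#include <cstddef>
#include <type_traits>

struct GemmKernelArgument
{
    const void* p_a_grid;
    const void* p_b_grid;
    void* p_c_grid;
    int M, N, K;
    int StrideA, StrideB, StrideC;
};

static_assert(std::is_trivially_copyable<GemmKernelArgument>::value,
              "kernel arguments are transferred with a raw device memcpy");

// Workspace sized for one argument per group, as the device op expects.
constexpr std::size_t workspace_bytes(std::size_t group_count)
{
    return group_count * sizeof(GemmKernelArgument);
}
```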
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
```diff
@@ -83,26 +83,38 @@ __global__ void
         const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, StrideC);
         const auto local_b2c_tile_map =
             Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, k_batch};
-        const auto block_2_ctile_map =
-            GroupedGemmBlock2ETileMap(local_b2c_tile_map, group_id * block_size);

-        GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
-            p_a_grid, p_b_grid, p_c_grid, M, N, K, StrideA, StrideB, StrideC,
-            MPadded, NPadded, KPadded, K0, k_batch,
-            static_cast<void*>(p_shared), block_2_ctile_map);
+        const auto m_loops = local_b2c_tile_map.CalculateMLoops(c_grid_desc_m_n);
+        index_t m_id       = 0;
+
+        do
+        {
+            const auto block_2_ctile_map =
+                GroupedGemmBlock2ETileMap(local_b2c_tile_map, group_id * block_size, m_id);
+
+            GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
+                p_a_grid, p_b_grid, p_c_grid, M, N, K, StrideA, StrideB, StrideC,
+                MPadded, NPadded, KPadded, K0, k_batch,
+                static_cast<void*>(p_shared), block_2_ctile_map);
+
+            m_id += 1;
+        } while(m_id < m_loops);
 #else
     ignore = gemm_descs_const;
     ignore = group_count;
```
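With this change a workgroup no longer maps to exactly one output tile: it re-runs the gridwise GEMM once per M tile, and CalculateMLoops (added in block_to_ctile_map.hpp below) is a ceiling division of M by the tile height. A scalar model of the loop count (standalone; M and MPerBlock are hypothetical, and the per-pass row ranges assume the map advances one M tile per m_id):

```cpp
// Scalar model of the kernel's new tile loop: m_loops counts M tiles via a
// ceiling division, and a single workgroup sweeps m_id = 0 .. m_loops-1.
#include <algorithm>
#include <iostream>

int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

int main()
{
    const int M = 300, MPerBlock = 128;                     // hypothetical GEMM M, tile height
    const int m_loops = integer_divide_ceil(M, MPerBlock);  // = 3 passes

    int m_id = 0;
    do
    {
        // Each pass rebuilds the tile map with an M-block offset of m_id.
        std::cout << "pass m_id = " << m_id << " handles rows ["
                  << m_id * MPerBlock << ", "
                  << std::min((m_id + 1) * MPerBlock, M) << ")\n";
        m_id += 1;
    } while(m_id < m_loops);
}
```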
```diff
@@ -267,11 +279,21 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayout,
             grid_size_   = 0;
             group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());

-            if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) &&
-                 group_count_ == ck::type_convert<ck::index_t>(p_Bs.size()) &&
-                 group_count_ == ck::type_convert<ck::index_t>(p_Es.size())))
+            if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) ||
+                 0 == ck::type_convert<ck::index_t>(p_As.size())))
             {
-                throw std::runtime_error("wrong! group_count_ != p_As/b/c.size");
+                throw std::runtime_error("wrong! group_count_ != p_As || 0 != p_As.size");
+            }
+            if(!(group_count_ == ck::type_convert<ck::index_t>(p_Bs.size()) ||
+                 0 == ck::type_convert<ck::index_t>(p_Bs.size())))
+            {
+                throw std::runtime_error("wrong! group_count_ != p_Bs || 0 != p_Bs.size");
+            }
+            if(!(group_count_ == ck::type_convert<ck::index_t>(p_Es.size())))
+            {
+                throw std::runtime_error("wrong! group_count_ != p_Es");
             }

             gemm_kernel_args_.reserve(group_count_);
```
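The relaxed checks mean each of p_As/p_Bs may now be either fully populated or left completely empty (when the device pointers arrive later through the workspace), while p_Es must still match the group count exactly. A compact model of the predicate (standalone sketch, hypothetical sizes):

```cpp
// Model of the relaxed validation: A/B pointer lists are all-or-nothing,
// the E (output) pointer list is mandatory and must match the group count.
#include <cassert>
#include <cstddef>
#include <vector>

bool sizes_ok(std::size_t group_count,
              const std::vector<const void*>& p_As,
              const std::vector<const void*>& p_Bs,
              const std::vector<void*>& p_Es)
{
    auto all_or_nothing = [&](std::size_t n) { return n == group_count || n == 0; };
    return all_or_nothing(p_As.size()) && all_or_nothing(p_Bs.size()) &&
           p_Es.size() == group_count;
}

int main()
{
    std::vector<void*> es(16, nullptr);
    assert(sizes_ok(16, {}, {}, es)); // A/B pointers supplied later: accepted
    assert(!sizes_ok(16, std::vector<const void*>(8), {}, es)); // partial A list: rejected
}
```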
```diff
@@ -297,29 +319,25 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayout,
                     Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
                 const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);

-                grid_size_ += grid_size_grp;
-
                 const index_t block_start = grid_size_;
                 const index_t block_end   = grid_size_ + grid_size_grp;

-                // block-to-e-tile map
-                auto grouped_block_2_ctile_map =
-                    GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
+                grid_size_ += grid_size_grp;

-                auto karg = KernelArgument{type_convert<const ADataType*>(p_As[i]),
-                                           type_convert<const BDataType*>(p_Bs[i]),
-                                           type_convert<EDataType*>(p_Es[i]),
-                                           M, N, K,
-                                           stride_a, stride_b, stride_c,
-                                           m_padded, n_padded, k_padded,
-                                           k0, K_BATCH};
+                auto karg = KernelArgument{
+                    p_As.size() == 0 ? nullptr : type_convert<const ADataType*>(p_As[i]),
+                    p_Bs.size() == 0 ? nullptr : type_convert<const BDataType*>(p_Bs[i]),
+                    type_convert<EDataType*>(p_Es[i]),
+                    M, N, K,
+                    stride_a, stride_b, stride_c,
+                    m_padded, n_padded, k_padded,
+                    k0, K_BATCH};

                 gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end);
             }
```
```diff
@@ -349,16 +367,11 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayout,
                 const auto local_b2c_tile_map =
                     Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
                 const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);

-                const index_t block_start = grid_size_;
-                const index_t block_end   = grid_size_ + grid_size_grp;
-
                 grid_size_ += grid_size_grp;

-                // block-to-e-tile map
-                auto grouped_block_2_ctile_map =
-                    GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
-
                 karg.KPadded = k_padded;
                 karg.K0      = k0;
                 karg.k_batch = K_BATCH;
```
```diff
@@ -378,30 +391,64 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayout,
     // Invoker
     struct Invoker : public BaseInvoker
     {
-        struct SimpleGemmArgument
-        {
-            const void* p_a_grid;
-            const void* p_b_grid;
-            void* p_c_grid;
-
-            index_t M;
-            index_t N;
-            index_t K;
-            index_t StrideA;
-            index_t StrideB;
-            index_t StrideC;
-        };
-
+        float Run(const Argument& arg,
+                  const void* gemm_descs_dev,
+                  const StreamConfig& stream_config = StreamConfig{})
+        {
-            using GemmArgumentType = SimpleGemmArgument;
+            using GemmArgumentType = GemmKernelArgument;
+
+            index_t K0                       = arg.gemm_kernel_args_[0].karg_.K0;
+            bool all_have_kbatch_gt_one      = arg.gemm_kernel_args_[0].karg_.k_batch > 1;
+            bool all_have_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+
+            for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
+            {
+                const auto& karg = arg.gemm_kernel_args_[i].karg_;
+                if(stream_config.log_level_ > 0)
+                {
+                    karg.Print();
+                }
+
+                std::cout << "Group id: " << i << " block_size: "
+                          << arg.gemm_kernel_args_[0].block_end_ -
+                                 arg.gemm_kernel_args_[0].block_start_
+                          << std::endl;
+
+                auto kbatch = karg.k_batch;
+
+                if(!GridwiseGemm::CheckValidity(karg))
+                {
+                    std::ostringstream err;
+                    err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__
+                        << ":" << __LINE__ << ", in function: " << __func__;
+                    throw std::runtime_error(err.str());
+                }
+
+                K0 = karg.K0;
+                bool not_all_have_main_k0_block_loop_same =
+                    all_have_main_k0_block_loop xor GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+                bool not_all_have_kbatch_value_same = all_have_kbatch_gt_one xor (kbatch > 1);
+
+                if(not_all_have_main_k0_block_loop_same)
+                {
+                    std::ostringstream err;
+                    err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__
+                        << ":" << __LINE__ << ", in function: " << __func__;
+                    throw std::runtime_error(err.str());
+                }
+
+                if(not_all_have_kbatch_value_same)
+                {
+                    std::ostringstream err;
+                    err << "Not all gemms have same kbatch value (=1 or >1)! "
+                        << "group [" << i << "], kbatch: " << kbatch
+                        << ", group [0], kbatch: " << arg.gemm_kernel_args_[0].karg_.k_batch
+                        << " in " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
+                    throw std::runtime_error(err.str());
+                }
+            }

             float ave_time = 0;

             const auto Run = [&](const auto& kernel) {
```
```diff
@@ -491,76 +538,35 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayout,
         float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
-            std::vector<SimpleGemmArgument> simple_gemm_kernel_args_;
-            simple_gemm_kernel_args_.reserve(arg.gemm_kernel_args_.size());
-
-            index_t K0                       = arg.gemm_kernel_args_[0].karg_.K0;
-            bool all_have_kbatch_gt_one      = arg.gemm_kernel_args_[0].karg_.k_batch > 1;
-            bool all_have_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+            std::vector<GemmKernelArgument> grouped_gemm_kernel_args_;
+            grouped_gemm_kernel_args_.reserve(arg.gemm_kernel_args_.size());

             for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
             {
                 const auto& karg = arg.gemm_kernel_args_[i].karg_;
-                if(stream_config.log_level_ > 0)
-                {
-                    karg.Print();
-                }
-
-                auto kbatch = karg.k_batch;
-
-                std::cout << "Group id: " << i << " block_size: "
-                          << arg.gemm_kernel_args_[i].block_end_ -
-                                 arg.gemm_kernel_args_[i].block_start_
-                          << std::endl;
-
-                if(!GridwiseGemm::CheckValidity(karg))
-                {
-                    std::ostringstream err;
-                    err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__
-                        << ":" << __LINE__ << ", in function: " << __func__;
-                    throw std::runtime_error(err.str());
-                }
-
-                K0 = karg.K0;
-                bool not_all_have_main_k0_block_loop_same =
-                    all_have_main_k0_block_loop xor GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
-                bool not_all_have_kbatch_value_same = all_have_kbatch_gt_one xor (kbatch > 1);
-
-                if(not_all_have_main_k0_block_loop_same)
-                {
-                    std::ostringstream err;
-                    err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__
-                        << ":" << __LINE__ << ", in function: " << __func__;
-                    throw std::runtime_error(err.str());
-                }
-
-                if(not_all_have_kbatch_value_same)
+                if(karg.p_a_grid == nullptr || karg.p_b_grid == nullptr ||
+                   karg.p_c_grid == nullptr)
                 {
-                    std::ostringstream err;
-                    err << "Not all gemms have same kbatch value (=1 or >1)! "
-                        << "group [" << i << "], kbatch: " << kbatch
-                        << ", group [0], kbatch: " << arg.gemm_kernel_args_[0].karg_.k_batch
-                        << " in " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
-                    throw std::runtime_error(err.str());
+                    throw std::runtime_error("wrong! p_a/b/c_grid is nullptr");
                 }

-                simple_gemm_kernel_args_.push_back({karg.p_a_grid,
-                                                    karg.p_b_grid,
-                                                    karg.p_c_grid,
-                                                    karg.M, karg.N, karg.K,
-                                                    karg.StrideA, karg.StrideB, karg.StrideC});
+                grouped_gemm_kernel_args_.push_back({karg.p_a_grid,
+                                                     karg.p_b_grid,
+                                                     karg.p_c_grid,
+                                                     karg.M, karg.N, karg.K,
+                                                     karg.StrideA, karg.StrideB, karg.StrideC});
             }

-            using GemmArgumentType = SimpleGemmArgument;
+            using GemmArgumentType = GemmKernelArgument;

             hip_check_error(hipMemcpyWithStream(arg.p_workspace_,
-                                                simple_gemm_kernel_args_.data(),
-                                                simple_gemm_kernel_args_.size() * sizeof(GemmArgumentType),
+                                                grouped_gemm_kernel_args_.data(),
+                                                grouped_gemm_kernel_args_.size() * sizeof(GemmArgumentType),
                                                 hipMemcpyHostToDevice,
                                                 stream_config.stream_id_));
```
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
```diff
@@ -315,6 +315,11 @@ struct BlockToCTileMap_KSplit_M00_N0_M01Adapt
     __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; }

+    __device__ index_t CalculateMLoops(const CGridDesc_M_N& c_grid_desc_m_n) const
+    {
+        return math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
+    }
+
     private:
     index_t M01_;
     index_t KSplit_;
```
```diff
@@ -586,17 +591,22 @@ struct OffsettedBlockToCTileMap
     using underlying_type = UnderlyingBlockToCTileMap;

     __host__ __device__ OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map,
-                                                 index_t block_start)
+                                                 index_t block_start,
+                                                 index_t mblock_id_off = 0)
     {
         block_to_ctile_map_ = block_to_ctile_map;
         block_start_        = block_start;
+        mblock_id_off_      = mblock_id_off;
     }

     template <typename TopIdx>
     __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
     {
-        return block_to_ctile_map_.CalculateBottomIndex(
-            make_multi_index(idx_top[Number<0>{}] - block_start_));
+        auto idx_bot = block_to_ctile_map_.CalculateBottomIndex(
+            make_multi_index(idx_top[Number<0>{}] - block_start_));
+        return make_tuple(idx_bot[Number<0>{}],
+                          idx_bot[Number<1>{}] + mblock_id_off_,
+                          idx_bot[Number<2>{}]);
     }

     template <typename CTileIdx, typename CTileDim>
```
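CalculateBottomIndex now rebuilds the returned tuple so the decoded M-block coordinate can be shifted by mblock_id_off_, on top of the existing block_start_ shift of the flat block id. A standalone scalar model of that index arithmetic (OffsetMapModel and its layout are illustrative stand-ins, not the CK types):

```cpp
// Scalar model of the offsetted tile map: the flat block id is first shifted
// by block_start (the group's first block), then the decoded M-block index
// is shifted by mblock_id_off (the m_id of the current kernel pass).
#include <array>
#include <iostream>

struct OffsetMapModel
{
    int block_start;   // first flat block id belonging to this group
    int mblock_id_off; // extra M-block offset applied per m_id pass
    int NBlocks;       // N tiles per M row in the underlying map (assumed layout)

    // Returns {ksplit, mblock, nblock}; ksplit fixed to 0 for brevity.
    std::array<int, 3> CalculateBottomIndex(int flat_block_id) const
    {
        const int local = flat_block_id - block_start;
        return {0, local / NBlocks + mblock_id_off, local % NBlocks};
    }
};

int main()
{
    OffsetMapModel map{/*block_start=*/64, /*mblock_id_off=*/2, /*NBlocks=*/4};
    const auto idx = map.CalculateBottomIndex(70); // local id 6 -> m=1, n=2 before offset
    std::cout << "m-block " << idx[1] << ", n-block " << idx[2] << '\n'; // m-block 3, n-block 2
}
```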
```diff
@@ -620,6 +630,7 @@ struct OffsettedBlockToCTileMap
     UnderlyingBlockToCTileMap block_to_ctile_map_;
     index_t block_start_;
+    index_t mblock_id_off_;
 };

 /**
```
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
```diff
@@ -621,9 +621,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
             return;
         }

-        const index_t k_batch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
         const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
         const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I2]);
+        const index_t k_batch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);

         // HACK: this force m/n_block_data_idx_on_grid into SGPR
         const index_t m_block_data_idx_on_grid =
```
library/include/ck/library/utility/device_memory.hpp
```diff
@@ -25,7 +25,9 @@ struct DeviceMem
     void* GetDeviceBuffer() const;
     std::size_t GetBufferSize() const;
     void ToDevice(const void* p) const;
+    void ToDevice(const void* p, const std::size_t cpySize) const;
     void FromDevice(void* p) const;
+    void FromDevice(void* p, const std::size_t cpySize) const;
     void SetZero() const;
     template <typename T>
     void SetValue(T x) const;
```
library/src/utility/device_memory.cpp
```diff
@@ -19,11 +19,21 @@ void DeviceMem::ToDevice(const void* p) const
     hip_check_error(hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
 }

+void DeviceMem::ToDevice(const void* p, const std::size_t cpySize) const
+{
+    hip_check_error(hipMemcpy(mpDeviceBuf, const_cast<void*>(p), cpySize, hipMemcpyHostToDevice));
+}
+
 void DeviceMem::FromDevice(void* p) const
 {
     hip_check_error(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
 }

+void DeviceMem::FromDevice(void* p, const std::size_t cpySize) const
+{
+    hip_check_error(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
+}
+
 void DeviceMem::SetZero() const
 {
     hip_check_error(hipMemset(mpDeviceBuf, 0, mMemSize));
 }

 DeviceMem::~DeviceMem()
 {
     hip_check_error(hipFree(mpDeviceBuf));
 }
```