Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
6d9425ec
Commit
6d9425ec
authored
Mar 16, 2022
by
Jianfeng yan
Browse files
added block2CTileMap, but results are not correct
parent
f4f94f70
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
62 additions
and
10 deletions
+62
-10
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
...k/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
+55
-4
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
+7
-6
No files found.
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
View file @
6d9425ec
...
...
@@ -7,6 +7,7 @@
#include "device_base.hpp"
#include "device_gemm.hpp"
#include "common_header.hpp"
#include "statically_indexed_array.hpp"
#include "tensor_layout.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
...
...
@@ -181,6 +182,51 @@ struct DeviceGroupedGemmXdl
using
BGridDesc_K0_N_K1
=
decltype
(
MakeBGridDescriptor_K0_N_K1
(
1
,
1
,
1
));
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
template
<
int
GroupCount
>
struct
Block2CTileMap
{
Block2CTileMap
()
=
default
;
template
<
typename
GemmDesc
>
// Block2CTileMap(const StaticallyIndexedArray<GemmDesc, GroupCount>& gemm_desc)
Block2CTileMap
(
const
std
::
vector
<
GemmDesc
>&
gemm_desc
,
const
index_t
N0
)
:
N0_
{
N0
}
{
for
(
index_t
grp
=
0
;
grp
<
GroupCount
-
1
;
++
grp
)
{
assert
(
gemm_desc
[
grp
].
BlockEnd
==
gemm_desc
[
grp
+
1
].
BlockStart
);
}
for
(
index_t
grp
=
0
;
grp
<
GroupCount
;
++
grp
)
{
block_ptr_
[
grp
]
=
gemm_desc
[
grp
].
BlockStart
;
}
block_ptr_
[
GroupCount
]
=
gemm_desc
[
GroupCount
-
1
].
BlockEnd
;
}
template
<
typename
Index
>
__host__
__device__
constexpr
auto
CalculateBottomIndex
(
Index
blockIdx
)
const
{
index_t
block_id
=
blockIdx
[
Number
<
0
>
{}];
index_t
local_block_id
;
for
(
index_t
grp
=
0
;
grp
<
MaxGroupCount
;
++
grp
)
{
if
(
block_id
>=
block_ptr_
[
grp
]
&&
block_id
<
block_ptr_
[
grp
+
1
])
{
local_block_id
=
block_id
-
block_ptr_
[
grp
];
}
}
return
make_tuple
(
local_block_id
/
N0_
,
local_block_id
%
N0_
);
// return make_tuple(local_block_id % N0_, local_block_id / N0_);
}
private:
index_t
block_ptr_
[
GroupCount
+
1
];
index_t
N0_
;
};
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
<
BlockSize
,
...
...
@@ -232,13 +278,14 @@ struct DeviceGroupedGemmXdl
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
const
ADataType
*
a_ptr
;
const
BDataType
*
b_ptr
;
CDataType
*
c_ptr
;
ck
::
index_t
BlockStart
,
BlockEnd
;
// typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
Block2CTileMap
<
MaxGroupCount
>
block_2_ctile_map_
;
};
// Argument
...
...
@@ -301,15 +348,13 @@ struct DeviceGroupedGemmXdl
const
auto
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_m_n_
);
const
auto
block_2_ctile_map_
=
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
,
M01
,
N01
);
gemm_desc_kernel_arg_
.
push_back
(
GemmDescKernelArg
{
a_grid_desc_k0_m_k1_
,
b_grid_desc_k0_n_k1_
,
c_grid_desc_m_n_
,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
block_2_ctile_map_
,
//
block_2_ctile_map_,
static_cast
<
const
ADataType
*>
(
p_a
[
i
]),
static_cast
<
const
BDataType
*>
(
p_b
[
i
]),
static_cast
<
CDataType
*>
(
p_c
[
i
]),
...
...
@@ -317,6 +362,11 @@ struct DeviceGroupedGemmXdl
BlockEnd
});
}
}
for
(
index_t
i
=
0
;
i
<
gemm_shapes
.
size
();
i
++
)
{
gemm_desc_kernel_arg_
[
i
].
block_2_ctile_map_
=
Block2CTileMap
<
MaxGroupCount
>
{
gemm_desc_kernel_arg_
,
gemm_desc_kernel_arg_
[
i
].
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
.
GetLength
(
Number
<
1
>
{})};
}
}
// private:
...
...
@@ -328,6 +378,7 @@ struct DeviceGroupedGemmXdl
CElementwiseOperation
c_element_op_
;
std
::
vector
<
GemmDescKernelArg
>
gemm_desc_kernel_arg_
;
// Block2CTileMap<MaxGroupCount> block_2_ctile_map_;
index_t
grid_size_
;
};
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
View file @
6d9425ec
...
...
@@ -73,6 +73,7 @@ __global__ void
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
)
{
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
...
...
@@ -84,7 +85,7 @@ __global__ void
i
<
group_count
)
{
auto
group_id
=
i
;
const
index_t
block_id_grp
=
block_id
-
gemm_desc_
[
group_id
].
BlockStart
;
//
const index_t block_id_grp = block_id - gemm_desc_[group_id].BlockStart;
GridwiseGemm
::
template
Run
<
HasMainK0BlockLoop
>(
gemm_desc_
[
group_id
].
a_ptr
,
...
...
@@ -97,8 +98,8 @@ __global__ void
a_element_op
,
b_element_op
,
c_element_op
,
gemm_desc_
[
group_id
].
block_2_ctile_map_
,
block_
id_gr
p
);
gemm_desc_
[
group_id
].
block_2_ctile_map_
);
//
block_
2_ctile_ma
p);
}
});
#else
...
...
@@ -426,8 +427,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CElementwiseOperation
&
c_element_op
,
const
Block2CTileMap
&
block_2_ctile_map
,
ck
::
index_t
block_id
=
get_block_1d_id
())
const
Block2CTileMap
&
block_2_ctile_map
)
//
ck::index_t block_id = get_block_1d_id())
{
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_a_grid
,
a_grid_desc_k0_m_k1
.
GetElementSpaceSize
());
...
...
@@ -440,7 +441,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// divide block work by [M, N]
const
auto
block_work_idx
=
block_2_ctile_map
.
CalculateBottomIndex
(
make_multi_index
(
block_
id
));
block_2_ctile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_
block_
1d_id
()
));
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const
index_t
m_block_data_idx_on_grid
=
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment