Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
1043ab4f
Commit
1043ab4f
authored
Sep 02, 2021
by
ltqin
Browse files
trans a and b to gridegemm
parent
a088771c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
41 additions
and
10 deletions
+41
-10
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp
...el/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp
+27
-4
host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp
host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp
+14
-6
No files found.
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp
View file @
1043ab4f
...
@@ -18,6 +18,8 @@ template <typename GridwiseGemm,
...
@@ -18,6 +18,8 @@ template <typename GridwiseGemm,
typename
FloatC
,
typename
FloatC
,
typename
AK0MK1GridDesc
,
typename
AK0MK1GridDesc
,
typename
BK0NK1GridDesc
,
typename
BK0NK1GridDesc
,
typename
ABK0MK1GridDesc
,
typename
BBK0NK1GridDesc
,
typename
CM0N0M1N1M2M3M4N2GridDesc
,
typename
CM0N0M1N1M2M3M4N2GridDesc
,
typename
CBlockClusterAdaptor
>
typename
CBlockClusterAdaptor
>
__global__
void
__global__
void
...
@@ -29,6 +31,8 @@ __global__ void
...
@@ -29,6 +31,8 @@ __global__ void
FloatC
*
__restrict__
p_c_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
AK0MK1GridDesc
a_k0_m_k1_grid_desc
,
const
AK0MK1GridDesc
a_k0_m_k1_grid_desc
,
const
BK0NK1GridDesc
b_k0_n_k1_grid_desc
,
const
BK0NK1GridDesc
b_k0_n_k1_grid_desc
,
const
void
CONSTANT
*
a_b_k0_m_k1_grid_desc
,
const
void
CONSTANT
*
b_b_k0_n_k1_grid_desc
,
const
CM0N0M1N1M2M3M4N2GridDesc
c_m0_m1_m2_n_grid_desc
,
const
CM0N0M1N1M2M3M4N2GridDesc
c_m0_m1_m2_n_grid_desc
,
const
CBlockClusterAdaptor
c_block_cluster_adaptor
)
const
CBlockClusterAdaptor
c_block_cluster_adaptor
)
{
{
...
@@ -43,6 +47,8 @@ __global__ void
...
@@ -43,6 +47,8 @@ __global__ void
p_shared_block
,
p_shared_block
,
a_k0_m_k1_grid_desc
,
a_k0_m_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
a_b_k0_m_k1_grid_desc
,
b_b_k0_n_k1_grid_desc
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
c_block_cluster_adaptor
);
c_block_cluster_adaptor
);
}
}
...
@@ -52,6 +58,8 @@ template <typename GridwiseGemm,
...
@@ -52,6 +58,8 @@ template <typename GridwiseGemm,
typename
FloatC
,
typename
FloatC
,
typename
AK0MK1GridDesc
,
typename
AK0MK1GridDesc
,
typename
BK0NK1GridDesc
,
typename
BK0NK1GridDesc
,
typename
ABK0MK1GridDesc
,
typename
BBK0NK1GridDesc
,
typename
CM0N0M1N1M2M3M4N2GridDesc
,
typename
CM0N0M1N1M2M3M4N2GridDesc
,
typename
CBlockClusterAdaptor
>
typename
CBlockClusterAdaptor
>
__global__
void
__global__
void
...
@@ -63,6 +71,8 @@ __global__ void
...
@@ -63,6 +71,8 @@ __global__ void
FloatC
*
__restrict__
p_c_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
void
CONSTANT
*
p_a_k0_m_k1_grid_desc
,
const
void
CONSTANT
*
p_a_k0_m_k1_grid_desc
,
const
void
CONSTANT
*
p_b_k0_n_k1_grid_desc
,
const
void
CONSTANT
*
p_b_k0_n_k1_grid_desc
,
const
void
CONSTANT
*
p_a_b_k0_m_k1_grid_desc
,
const
void
CONSTANT
*
p_b_b_k0_n_k1_grid_desc
,
const
void
CONSTANT
*
p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
const
void
CONSTANT
*
p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
const
void
CONSTANT
*
p_c_block_cluster_adaptor
)
const
void
CONSTANT
*
p_c_block_cluster_adaptor
)
{
{
...
@@ -73,6 +83,10 @@ __global__ void
...
@@ -73,6 +83,10 @@ __global__ void
cast_pointer_to_generic_address_space
(
p_a_k0_m_k1_grid_desc
));
cast_pointer_to_generic_address_space
(
p_a_k0_m_k1_grid_desc
));
const
auto
b_k0_n_k1_grid_desc
=
*
reinterpret_cast
<
const
BK0NK1GridDesc
*>
(
const
auto
b_k0_n_k1_grid_desc
=
*
reinterpret_cast
<
const
BK0NK1GridDesc
*>
(
cast_pointer_to_generic_address_space
(
p_b_k0_n_k1_grid_desc
));
cast_pointer_to_generic_address_space
(
p_b_k0_n_k1_grid_desc
));
const
auto
a_b_k0_m_k1_grid_desc
=
*
reinterpret_cast
<
const
ABK0MK1GridDesc
*>
(
cast_pointer_to_generic_address_space
(
p_a_b_k0_m_k1_grid_desc
));
const
auto
b_b_k0_n_k1_grid_desc
=
*
reinterpret_cast
<
const
BBK0NK1GridDesc
*>
(
cast_pointer_to_generic_address_space
(
p_b_b_k0_n_k1_grid_desc
));
const
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
=
const
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
=
*
reinterpret_cast
<
const
CM0N0M1N1M2M3M4N2GridDesc
*>
(
*
reinterpret_cast
<
const
CM0N0M1N1M2M3M4N2GridDesc
*>
(
cast_pointer_to_generic_address_space
(
p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
));
cast_pointer_to_generic_address_space
(
p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
));
...
@@ -87,6 +101,8 @@ __global__ void
...
@@ -87,6 +101,8 @@ __global__ void
p_shared_block
,
p_shared_block
,
a_k0_m_k1_grid_desc
,
a_k0_m_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
a_b_k0_m_k1_grid_desc
,
b_b_k0_n_k1_grid_desc
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
c_block_cluster_adaptor
);
c_block_cluster_adaptor
);
}
}
...
@@ -311,6 +327,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
...
@@ -311,6 +327,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
return
c_blockid_to_m0_n0_block_cluster_adaptor
;
return
c_blockid_to_m0_n0_block_cluster_adaptor
;
}
}
using
ABK0MK1GridDesc
=
decltype
(
MakeABK0MK1GridDescriptor
(
AK0MK1GridDesc
{},
I1
));
using
BBK0NK1GridDesc
=
decltype
(
MakeBBK0NK1GridDescriptor
(
BK0NK1GridDesc
{},
I1
));
using
CM0N0M1N1M2M3M4N2GridDesc
=
decltype
(
MakeCM0N0M1N1M2M3M4N2GridDescriptor
(
CMNGridDesc
{}));
using
CM0N0M1N1M2M3M4N2GridDesc
=
decltype
(
MakeCM0N0M1N1M2M3M4N2GridDescriptor
(
CMNGridDesc
{}));
using
CBlockClusterAdaptor
=
decltype
(
MakeCBlockClusterAdaptor
(
CMNGridDesc
{}));
using
CBlockClusterAdaptor
=
decltype
(
MakeCBlockClusterAdaptor
(
CMNGridDesc
{}));
...
@@ -320,6 +338,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
...
@@ -320,6 +338,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
FloatAB
*
__restrict__
p_shared_block
,
FloatAB
*
__restrict__
p_shared_block
,
const
AK0MK1GridDesc
&
a_k0_m_k1_grid_desc
,
const
AK0MK1GridDesc
&
a_k0_m_k1_grid_desc
,
const
BK0NK1GridDesc
&
b_k0_n_k1_grid_desc
,
const
BK0NK1GridDesc
&
b_k0_n_k1_grid_desc
,
const
ABK0MK1GridDesc
&
a_b_k0_m_k1_grid_desc
,
const
BBK0NK1GridDesc
&
b_b_k0_n_k1_grid_desc
,
const
CM0N0M1N1M2M3M4N2GridDesc
&
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
const
CM0N0M1N1M2M3M4N2GridDesc
&
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
const
CBlockClusterAdaptor
&
c_block_cluster_adaptor
)
const
CBlockClusterAdaptor
&
c_block_cluster_adaptor
)
{
{
...
@@ -330,12 +350,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
...
@@ -330,12 +350,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_c_grid
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
.
GetElementSpaceSize
());
p_c_grid
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
.
GetElementSpaceSize
());
const
auto
kbatch
=
CalculateKBatch
(
CMNGridDesc
{},
b_k0_n_k1_grid_desc
);
if
(
get_block_1d_id
()
==
0
)
printf
(
"*****kbatch : %d"
,
kbatch
);
const
auto
K0
=
a_k0_m_k1_grid_desc
.
GetLength
(
I0
);
const
auto
K0
=
a_k0_m_k1_grid_desc
.
GetLength
(
I0
);
const
auto
kbatch
=
CalculateKBatch
(
CMNGridDesc
{},
b_k0_n_k1_grid_desc
);
if
(
get_block_1d_id
()
==
0
)
printf
(
"*****kbatch : %d, %d, %d, %d
\n
"
,
kbatch
,
a_b_k0_m_k1_grid_desc
.
GetLength
(
I0
),
b_b_k0_n_k1_grid_desc
.
GetLength
(
I0
),
K0
);
// divide block work by [M, N]
// divide block work by [M, N]
const
auto
block_work_idx
=
const
auto
block_work_idx
=
c_block_cluster_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
c_block_cluster_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
...
...
host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp
View file @
1043ab4f
...
@@ -123,10 +123,10 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
...
@@ -123,10 +123,10 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
<<
c_m_n_grid_desc
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
<<
c_m_n_grid_desc
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
}
}
const
auto
kbatch
=
GridwiseGemm
::
CalculateKBatch
(
c_m_n_grid_desc
,
b_k0_n_k1_grid_desc
);
const
auto
kbatch
=
GridwiseGemm
::
CalculateKBatch
(
c_m_n_grid_desc
,
b_k0_n_k1_grid_desc
);
//
const auto a_b_k0_m_k1_grid_desc =
const
auto
a_b_k0_m_k1_grid_desc
=
GridwiseGemm
::
MakeABK0MK1GridDescriptor
(
a_k0_m_k1_grid_desc
,
kbatch
);
GridwiseGemm
::
MakeABK0MK1GridDescriptor
(
a_k0_m_k1_grid_desc
,
kbatch
);
//
const auto b_b_k0_n_k1_grid_desc =
const
auto
b_b_k0_n_k1_grid_desc
=
GridwiseGemm
::
MakeBBK0NK1GridDescriptor
(
b_k0_n_k1_grid_desc
,
kbatch
);
GridwiseGemm
::
MakeBBK0NK1GridDescriptor
(
b_k0_n_k1_grid_desc
,
kbatch
);
{
{
std
::
cout
<<
"k batch number is: "
<<
kbatch
<<
std
::
endl
;
std
::
cout
<<
"k batch number is: "
<<
kbatch
<<
std
::
endl
;
}
}
...
@@ -139,8 +139,8 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
...
@@ -139,8 +139,8 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
const
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
=
const
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
=
GridwiseGemm
::
MakeCM0N0M1N1M2M3M4N2GridDescriptor
(
c_m_n_grid_desc
);
GridwiseGemm
::
MakeCM0N0M1N1M2M3M4N2GridDescriptor
(
c_m_n_grid_desc
);
//
using ABK0MK1GridDesc = decltype(a_b_k0_m_k1_grid_desc);
using
ABK0MK1GridDesc
=
decltype
(
a_b_k0_m_k1_grid_desc
);
//
using BBK0NK1GridDesc = decltype(b_b_k0_n_k1_grid_desc);
using
BBK0NK1GridDesc
=
decltype
(
b_b_k0_n_k1_grid_desc
);
using
CM0N0M1N1M2M3M4N2GridDesc
=
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
);
using
CM0N0M1N1M2M3M4N2GridDesc
=
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
);
const
auto
c_block_cluster_adaptor
=
GridwiseGemm
::
MakeCBlockClusterAdaptor
(
c_m_n_grid_desc
);
const
auto
c_block_cluster_adaptor
=
GridwiseGemm
::
MakeCBlockClusterAdaptor
(
c_m_n_grid_desc
);
...
@@ -158,6 +158,8 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
...
@@ -158,6 +158,8 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
FloatC
,
FloatC
,
remove_reference_t
<
AK0MK1GridDesc
>
,
remove_reference_t
<
AK0MK1GridDesc
>
,
remove_reference_t
<
BK0NK1GridDesc
>
,
remove_reference_t
<
BK0NK1GridDesc
>
,
remove_reference_t
<
ABK0MK1GridDesc
>
,
remove_reference_t
<
BBK0NK1GridDesc
>
,
remove_reference_t
<
CM0N0M1N1M2M3M4N2GridDesc
>
,
remove_reference_t
<
CM0N0M1N1M2M3M4N2GridDesc
>
,
remove_reference_t
<
CBlockClusterAdaptor
>>
;
remove_reference_t
<
CBlockClusterAdaptor
>>
;
...
@@ -172,12 +174,16 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
...
@@ -172,12 +174,16 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
p_c_grid
,
p_c_grid
,
a_k0_m_k1_grid_desc
,
a_k0_m_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
a_b_k0_m_k1_grid_desc
,
b_b_k0_n_k1_grid_desc
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
c_block_cluster_adaptor
);
c_block_cluster_adaptor
);
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem
a_k0_m_k1_grid_desc_dev_buf
(
sizeof
(
AK0MK1GridDesc
));
DeviceMem
a_k0_m_k1_grid_desc_dev_buf
(
sizeof
(
AK0MK1GridDesc
));
DeviceMem
b_k0_n_k1_grid_desc_dev_buf
(
sizeof
(
BK0NK1GridDesc
));
DeviceMem
b_k0_n_k1_grid_desc_dev_buf
(
sizeof
(
BK0NK1GridDesc
));
DeviceMem
a_b_k0_m_k1_grid_desc_dev_buf
(
sizeof
(
ABK0MK1GridDesc
));
DeviceMem
b_b_k0_n_k1_grid_desc_dev_buf
(
sizeof
(
BBK0NK1GridDesc
));
DeviceMem
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf
(
sizeof
(
CM0N0M1N1M2M3M4N2GridDesc
));
DeviceMem
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf
(
sizeof
(
CM0N0M1N1M2M3M4N2GridDesc
));
DeviceMem
c_block_cluster_adaptor_dev_buf
(
sizeof
(
CBlockClusterAdaptor
));
DeviceMem
c_block_cluster_adaptor_dev_buf
(
sizeof
(
CBlockClusterAdaptor
));
...
@@ -197,6 +203,8 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
...
@@ -197,6 +203,8 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
p_c_grid
,
p_c_grid
,
cast_pointer_to_constant_address_space
(
a_k0_m_k1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
a_k0_m_k1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
b_k0_n_k1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
b_k0_n_k1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
a_b_k0_m_k1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
b_b_k0_n_k1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
cast_pointer_to_constant_address_space
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf
.
GetDeviceBuffer
()),
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
()));
cast_pointer_to_constant_address_space
(
c_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
()));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment