Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
bac321a9
Commit
bac321a9
authored
May 05, 2023
by
Po-Yen, Chen
Browse files
Remove 'block_2_ctile_map' kernel parameter
parent
a65d6459
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
31 additions
and
47 deletions
+31
-47
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp
...e/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp
+22
-35
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
+9
-12
No files found.
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp
View file @
bac321a9
...
@@ -243,9 +243,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
...
@@ -243,9 +243,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
index_t
K
,
index_t
K
,
index_t
StrideA
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideB
,
index_t
StrideC
,
index_t
StrideC
)
index_t
M01
,
index_t
N01
)
:
p_a_grid_
{
p_a_grid
},
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_b_grid_
{
p_b_grid
},
p_c_grid_
{
p_c_grid
},
p_c_grid_
{
p_c_grid
},
...
@@ -253,16 +251,13 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
...
@@ -253,16 +251,13 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
b_grid_desc_k0_n_k1_
{},
b_grid_desc_k0_n_k1_
{},
c_grid_desc_m_n_
{},
c_grid_desc_m_n_
{},
block_2_ctile_map_
{},
block_2_ctile_map_
{},
M01_
{
M01
},
N01_
{
N01
},
kraw_
{
K
}
kraw_
{
K
}
{
{
a_grid_desc_k0_m_k1_
=
DeviceGemmXdl
::
MakeAGridDescriptor_K0_M_K1
(
M
,
K
,
StrideA
);
a_grid_desc_k0_m_k1_
=
DeviceGemmXdl
::
MakeAGridDescriptor_K0_M_K1
(
M
,
K
,
StrideA
);
b_grid_desc_k0_n_k1_
=
DeviceGemmXdl
::
MakeBGridDescriptor_K0_N_K1
(
K
,
N
,
StrideB
);
b_grid_desc_k0_n_k1_
=
DeviceGemmXdl
::
MakeBGridDescriptor_K0_N_K1
(
K
,
N
,
StrideB
);
c_grid_desc_m_n_
=
DeviceGemmXdl
::
MakeCGridDescriptor_M_N
(
M
,
N
,
StrideC
);
c_grid_desc_m_n_
=
DeviceGemmXdl
::
MakeCGridDescriptor_M_N
(
M
,
N
,
StrideC
);
block_2_ctile_map_
=
block_2_ctile_map_
=
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
);
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
,
M01
,
N01
);
}
}
// private:
// private:
...
@@ -273,8 +268,6 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
...
@@ -273,8 +268,6 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
index_t
M01_
;
index_t
N01_
;
index_t
kraw_
;
index_t
kraw_
;
};
};
...
@@ -319,14 +312,13 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
...
@@ -319,14 +312,13 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r3
<
const
auto
kernel
=
GridwiseGemm
,
kernel_gemm_xdlops_v2r3
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
CDataType
,
remove_reference_t
<
DeviceGemmXdl
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmXdl
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmXdl
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceGemmXdl
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceGemmXdl
::
CGridDesc_M_N
>
,
remove_reference_t
<
DeviceGemmXdl
::
CGridDesc_M_N
>
,
remove_reference_t
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
,
true
>
;
true
>
;
ave_time
=
launch_and_time_kernel
(
stream_config
,
ave_time
=
launch_and_time_kernel
(
stream_config
,
...
@@ -339,19 +331,17 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
...
@@ -339,19 +331,17 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
arg
.
p_c_grid_
,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m_n_
,
arg
.
c_grid_desc_m_n_
);
arg
.
block_2_ctile_map_
);
}
}
else
else
{
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r3
<
const
auto
kernel
=
GridwiseGemm
,
kernel_gemm_xdlops_v2r3
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
CDataType
,
remove_reference_t
<
DeviceGemmXdl
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmXdl
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmXdl
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceGemmXdl
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceGemmXdl
::
CGridDesc_M_N
>
,
remove_reference_t
<
DeviceGemmXdl
::
CGridDesc_M_N
>
,
remove_reference_t
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
,
false
>
;
false
>
;
ave_time
=
launch_and_time_kernel
(
stream_config
,
ave_time
=
launch_and_time_kernel
(
stream_config
,
...
@@ -364,8 +354,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
...
@@ -364,8 +354,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
arg
.
p_c_grid_
,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m_n_
,
arg
.
c_grid_desc_m_n_
);
arg
.
block_2_ctile_map_
);
}
}
return
ave_time
;
return
ave_time
;
...
@@ -438,7 +427,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
...
@@ -438,7 +427,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
BElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
)
CElementwiseOperation
)
{
{
return
Argument
{
p_a
,
p_b
,
p_c
,
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
1
,
1
};
return
Argument
{
p_a
,
p_b
,
p_c
,
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
};
}
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
@@ -465,9 +454,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
...
@@ -465,9 +454,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
K
,
K
,
StrideA
,
StrideA
,
StrideB
,
StrideB
,
StrideC
,
StrideC
);
1
,
1
);
}
}
// polymorphic
// polymorphic
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
View file @
bac321a9
...
@@ -22,7 +22,6 @@ template <typename GridwiseGemm,
...
@@ -22,7 +22,6 @@ template <typename GridwiseGemm,
typename
AGridDesc_K0_M_K1
,
typename
AGridDesc_K0_M_K1
,
typename
BGridDesc_K0_N_K1
,
typename
BGridDesc_K0_N_K1
,
typename
CGridDesc_M_N
,
typename
CGridDesc_M_N
,
typename
Block2CTileMap
,
bool
HasMainKBlockLoop
>
bool
HasMainKBlockLoop
>
__global__
void
__global__
void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
...
@@ -33,8 +32,7 @@ __global__ void
...
@@ -33,8 +32,7 @@ __global__ void
FloatC
*
__restrict__
p_c_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1
,
const
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1
,
const
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1
,
const
CGridDesc_M_N
c_grid_desc_m_n
,
const
CGridDesc_M_N
c_grid_desc_m_n
)
const
Block2CTileMap
block_2_ctile_map
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
defined(__gfx940__))
...
@@ -46,8 +44,7 @@ __global__ void
...
@@ -46,8 +44,7 @@ __global__ void
p_shared
,
p_shared
,
a_grid_desc_k0_m_k1
,
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
b_grid_desc_k0_n_k1
,
c_grid_desc_m_n
,
c_grid_desc_m_n
);
block_2_ctile_map
);
#else
#else
ignore
=
p_a_grid
;
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
ignore
=
p_b_grid
;
...
@@ -55,7 +52,6 @@ __global__ void
...
@@ -55,7 +52,6 @@ __global__ void
ignore
=
a_grid_desc_k0_m_k1
;
ignore
=
a_grid_desc_k0_m_k1
;
ignore
=
b_grid_desc_k0_n_k1
;
ignore
=
b_grid_desc_k0_n_k1
;
ignore
=
c_grid_desc_m_n
;
ignore
=
c_grid_desc_m_n
;
ignore
=
block_2_ctile_map
;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
}
...
@@ -293,8 +289,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -293,8 +289,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
}
}
// return block_id to C matrix tile idx (m0, n0) mapping
// return block_id to C matrix tile idx (m0, n0) mapping
__host__
__device__
static
constexpr
auto
MakeDefaultBlock2CTileMap
(
__host__
__device__
static
constexpr
auto
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
index_t
/* M01 */
,
index_t
/* N01 */
)
MakeDefaultBlock2CTileMap
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
{
return
BlockToCTileMap_M00_N0_M01Adapt
<
MPerBlock
,
NPerBlock
,
CGridDesc_M_N
>
(
return
BlockToCTileMap_M00_N0_M01Adapt
<
MPerBlock
,
NPerBlock
,
CGridDesc_M_N
>
(
c_grid_desc_m_n
);
c_grid_desc_m_n
);
...
@@ -302,17 +298,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -302,17 +298,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
using
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
=
using
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
=
decltype
(
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
CGridDesc_M_N
{}));
decltype
(
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
CGridDesc_M_N
{}));
using
DefaultBlock2CTileMap
=
decltype
(
MakeDefaultBlock2CTileMap
(
CGridDesc_M_N
{}
,
1
,
1
));
using
DefaultBlock2CTileMap
=
decltype
(
MakeDefaultBlock2CTileMap
(
CGridDesc_M_N
{}));
template
<
bool
HasMainKBlockLoop
,
typename
Block2CTileMap
=
DefaultBlock2CTileMap
>
template
<
bool
HasMainKBlockLoop
>
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatC
*
__restrict__
p_c_grid
,
void
*
__restrict__
p_shared
,
void
*
__restrict__
p_shared
,
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
const
Block2CTileMap
&
block_2_ctile_map
)
{
{
const
auto
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
const
auto
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_m_n
);
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_m_n
);
...
@@ -330,6 +325,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -330,6 +325,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const
auto
K0
=
a_grid_desc_k0_m_k1
.
GetLength
(
I0
);
const
auto
K0
=
a_grid_desc_k0_m_k1
.
GetLength
(
I0
);
const
auto
block_2_ctile_map
=
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n
);
// divide block work by [M, N]
// divide block work by [M, N]
const
auto
block_work_idx
=
const
auto
block_work_idx
=
block_2_ctile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
block_2_ctile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment