Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
9be8900f
Commit
9be8900f
authored
May 05, 2023
by
Po-Yen, Chen
Browse files
Push-down class 'GridwiseGemm::Argument' fields
parent
7bae1691
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
160 additions
and
122 deletions
+160
-122
include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp
...ation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp
+88
-76
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
...or_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
+54
-25
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
+18
-21
No files found.
include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp
View file @
9be8900f
...
...
@@ -451,26 +451,18 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
index_t
K_
,
index_t
StrideA_
,
index_t
StrideB_
,
index_t
StrideC_
,
index_t
MPadded_
,
index_t
NPadded_
,
index_t
KPadded_
,
index_t
AK0_
,
index_t
BK0_
)
:
Parent
(
nullptr
,
nullptr
,
nullptr
,
M_
,
index_t
StrideC_
)
:
Parent
(
M_
,
N_
,
K_
,
StrideA_
,
StrideB_
,
StrideC_
,
MPadded
_
,
NPadded
_
,
KPadded
_
,
AK0_
,
BK0_
),
GridwiseGemm
::
Calculate
MPadded
(
M_
)
,
GridwiseGemm
::
Calculate
NPadded
(
N_
)
,
GridwiseGemm
::
Calculate
KPadded
(
K_
)
,
GridwiseGemm
::
CalculateAK0
(
K_
)
,
GridwiseGemm
::
CalculateBK0
(
K_
)
),
p_a_grid_real_
{
p_a_grid_real
},
p_a_grid_imag_
{
p_a_grid_imag
},
p_b_grid_real_
{
p_b_grid_real
},
...
...
@@ -510,15 +502,13 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
// Invoker
struct
Invoker
:
public
BaseInvoker
{
//
void Print(const Argument& karg) { karg.Print(); }
void
Print
(
const
Argument
&
karg
)
{
karg
.
Print
();
}
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
float
Run
(
const
Argument
&
k
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
Argument
karg
=
arg
;
if
(
stream_config
.
log_level_
>
0
)
{
//
Print(karg);
Print
(
karg
);
}
if
(
!
GridwiseGemm
::
CheckValidity
(
karg
))
...
...
@@ -575,17 +565,25 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v1_simplified
<
GridwiseGemm
,
true
>
;
karg
.
p_a_grid
=
karg
.
p_a_grid_real_
;
karg
.
p_b_grid
=
karg
.
p_b_grid_real_
;
karg
.
p_c_grid
=
karg
.
p_aux_grid_
;
ave_time
+=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
);
karg
.
p_a_grid
=
karg
.
p_a_grid_imag_
;
karg
.
p_b_grid
=
karg
.
p_b_grid_imag_
;
karg
.
p_c_grid
=
karg
.
p_aux_2_grid_
;
ave_time
+=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
);
ave_time
+=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
.
p_a_grid_real_
,
karg
.
p_b_grid_real_
,
karg
.
p_aux_grid_
,
karg
);
ave_time
+=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
.
p_a_grid_imag_
,
karg
.
p_b_grid_imag_
,
karg
.
p_aux_2_grid_
,
karg
);
// c_real = aux - aux_2
ave_time
+=
launch_and_time_kernel
(
...
...
@@ -601,17 +599,25 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
make_tuple
(
karg
.
p_c_grid_real_
),
Subtract
{});
karg
.
p_a_grid
=
karg
.
p_a_grid_real_
;
karg
.
p_b_grid
=
karg
.
p_b_grid_imag_
;
karg
.
p_c_grid
=
karg
.
p_aux_grid_
;
ave_time
+=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
);
karg
.
p_a_grid
=
karg
.
p_a_grid_imag_
;
karg
.
p_b_grid
=
karg
.
p_b_grid_real_
;
karg
.
p_c_grid
=
karg
.
p_aux_2_grid_
;
ave_time
+=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
);
ave_time
+=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
.
p_a_grid_real_
,
karg
.
p_b_grid_imag_
,
karg
.
p_aux_grid_
,
karg
);
ave_time
+=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
.
p_a_grid_imag_
,
karg
.
p_b_grid_real_
,
karg
.
p_aux_2_grid_
,
karg
);
// c_imag = aux + aux_2
ave_time
+=
launch_and_time_kernel
(
...
...
@@ -631,17 +637,25 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v1_simplified
<
GridwiseGemm
,
false
>
;
karg
.
p_a_grid
=
karg
.
p_a_grid_real_
;
karg
.
p_b_grid
=
karg
.
p_b_grid_real_
;
karg
.
p_c_grid
=
karg
.
p_aux_grid_
;
ave_time
+=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
);
karg
.
p_a_grid
=
karg
.
p_a_grid_imag_
;
karg
.
p_b_grid
=
karg
.
p_b_grid_imag_
;
karg
.
p_c_grid
=
karg
.
p_aux_2_grid_
;
ave_time
+=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
);
ave_time
+=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
.
p_a_grid_real_
,
karg
.
p_b_grid_real_
,
karg
.
p_aux_grid_
,
karg
);
ave_time
+=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
.
p_a_grid_imag_
,
karg
.
p_b_grid_imag_
,
karg
.
p_aux_2_grid_
,
karg
);
// c_real = aux - aux_2
ave_time
+=
launch_and_time_kernel
(
...
...
@@ -657,17 +671,25 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
make_tuple
(
karg
.
p_c_grid_real_
),
Subtract
{});
karg
.
p_a_grid
=
karg
.
p_a_grid_real_
;
karg
.
p_b_grid
=
karg
.
p_b_grid_imag_
;
karg
.
p_c_grid
=
karg
.
p_aux_grid_
;
ave_time
+=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
);
karg
.
p_a_grid
=
karg
.
p_a_grid_imag_
;
karg
.
p_b_grid
=
karg
.
p_b_grid_real_
;
karg
.
p_c_grid
=
karg
.
p_aux_2_grid_
;
ave_time
+=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
);
ave_time
+=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
.
p_a_grid_real_
,
karg
.
p_b_grid_imag_
,
karg
.
p_aux_grid_
,
karg
);
ave_time
+=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
.
p_a_grid_imag_
,
karg
.
p_b_grid_real_
,
karg
.
p_aux_2_grid_
,
karg
);
// c_imag = aux + aux_2
ave_time
+=
launch_and_time_kernel
(
...
...
@@ -741,12 +763,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
K
,
StrideA
,
StrideB
,
StrideC
,
GridwiseGemm
::
CalculateMPadded
(
M
),
GridwiseGemm
::
CalculateNPadded
(
N
),
GridwiseGemm
::
CalculateKPadded
(
K
),
GridwiseGemm
::
CalculateAK0
(
K
),
GridwiseGemm
::
CalculateBK0
(
K
)};
StrideC
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
...
@@ -782,12 +799,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
K
,
StrideA
,
StrideB
,
StrideC
,
GridwiseGemm
::
CalculateMPadded
(
M
),
GridwiseGemm
::
CalculateNPadded
(
N
),
GridwiseGemm
::
CalculateKPadded
(
K
),
GridwiseGemm
::
CalculateAK0
(
K
),
GridwiseGemm
::
CalculateBK0
(
K
));
StrideC
);
}
// polymorphic
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
View file @
9be8900f
...
...
@@ -130,7 +130,40 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
LoopSched
,
PipelineVer
>
;
using
Argument
=
typename
GridwiseGemm
::
Argument
;
struct
Argument
:
public
GridwiseGemm
::
Argument
{
using
Parent
=
typename
GridwiseGemm
::
Argument
;
Argument
(
const
ADataType
*
p_a_grid_
,
const
BDataType
*
p_b_grid_
,
CDataType
*
p_c_grid_
,
index_t
M_
,
index_t
N_
,
index_t
K_
,
index_t
StrideA_
,
index_t
StrideB_
,
index_t
StrideC_
)
:
Parent
(
M_
,
N_
,
K_
,
StrideA_
,
StrideB_
,
StrideC_
,
GridwiseGemm
::
CalculateMPadded
(
M_
),
GridwiseGemm
::
CalculateNPadded
(
N_
),
GridwiseGemm
::
CalculateKPadded
(
K_
),
GridwiseGemm
::
CalculateAK0
(
K_
),
GridwiseGemm
::
CalculateBK0
(
K_
)),
p_a_grid
{
p_a_grid_
},
p_b_grid
{
p_b_grid_
},
p_c_grid
{
p_c_grid_
}
{
}
const
ADataType
*
p_a_grid
;
const
BDataType
*
p_b_grid
;
CDataType
*
p_c_grid
;
};
// Invoker
struct
Invoker
:
public
BaseInvoker
...
...
@@ -160,15 +193,29 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v1_simplified
<
GridwiseGemm
,
true
>
;
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
);
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
.
p_a_grid
,
karg
.
p_b_grid
,
karg
.
p_c_grid
,
karg
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v1_simplified
<
GridwiseGemm
,
false
>
;
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
);
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
.
p_a_grid
,
karg
.
p_b_grid
,
karg
.
p_c_grid
,
karg
);
}
return
ave_time
;
...
...
@@ -212,20 +259,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
BElementwiseOperation
,
CElementwiseOperation
)
{
return
Argument
{
p_a
,
p_b
,
p_c
,
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
GridwiseGemm
::
CalculateMPadded
(
M
),
GridwiseGemm
::
CalculateNPadded
(
N
),
GridwiseGemm
::
CalculateKPadded
(
K
),
GridwiseGemm
::
CalculateAK0
(
K
),
GridwiseGemm
::
CalculateBK0
(
K
)};
return
Argument
{
p_a
,
p_b
,
p_c
,
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
...
@@ -252,12 +286,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
K
,
StrideA
,
StrideB
,
StrideC
,
GridwiseGemm
::
CalculateMPadded
(
M
),
GridwiseGemm
::
CalculateNPadded
(
N
),
GridwiseGemm
::
CalculateKPadded
(
K
),
GridwiseGemm
::
CalculateAK0
(
K
),
GridwiseGemm
::
CalculateBK0
(
K
));
StrideC
);
}
// polymorphic
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
View file @
9be8900f
...
...
@@ -22,13 +22,17 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_gemm_xdl_cshuffle_v1_simplified
(
typename
GridwiseGemm
::
Argument
karg
)
kernel_gemm_xdl_cshuffle_v1_simplified
(
const
typename
GridwiseGemm
::
FloatAB
*
__restrict__
p_a_grid
,
const
typename
GridwiseGemm
::
FloatAB
*
__restrict__
p_b_grid
,
typename
GridwiseGemm
::
FloatC
*
__restrict__
p_c_grid
,
typename
GridwiseGemm
::
Argument
karg
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
karg
,
p_shared
);
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
p_b_grid
,
p_c_grid
,
p_shared
,
karg
);
#else
ignore
=
karg
;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
...
...
@@ -37,10 +41,10 @@ __global__ void
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
FloatAB
,
typename
FloatAB
_
,
typename
FloatGemmAcc
,
typename
FloatCShuffle
,
typename
FloatC
,
typename
FloatC
_
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
...
...
@@ -96,6 +100,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
static
constexpr
auto
AK1_c
=
Number
<
AK1Value
>
{};
static
constexpr
auto
BK1_c
=
Number
<
BK1Value
>
{};
using
FloatAB
=
FloatAB_
;
using
FloatC
=
FloatC_
;
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
#if defined(INTEGER_DIVIDE_CEIL)
...
...
@@ -390,10 +397,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// Argument
struct
Argument
:
public
tensor_operation
::
device
::
BaseArgument
{
__host__
Argument
(
const
FloatAB
*
p_a_grid_
,
const
FloatAB
*
p_b_grid_
,
FloatC
*
p_c_grid_
,
index_t
M_
,
__host__
Argument
(
index_t
M_
,
index_t
N_
,
index_t
K_
,
index_t
StrideA_
,
...
...
@@ -404,10 +408,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
index_t
KPadded_
,
index_t
AK0_
,
index_t
BK0_
)
:
p_a_grid
{
p_a_grid_
},
p_b_grid
{
p_b_grid_
},
p_c_grid
{
p_c_grid_
},
M
{
M_
},
:
M
{
M_
},
N
{
N_
},
K
{
K_
},
StrideA
{
StrideA_
},
...
...
@@ -446,10 +447,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
__host__
__device__
~
Argument
()
override
{}
// private:
const
FloatAB
*
p_a_grid
;
const
FloatAB
*
p_b_grid
;
FloatC
*
p_c_grid
;
index_t
M
;
index_t
N
;
index_t
K
;
...
...
@@ -673,12 +670,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
template
<
bool
HasMainKBlockLoop
>
__device__
static
void
Run
(
const
Argument
&
karg
,
void
*
__restrict__
p_shared
)
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
void
*
__restrict__
p_shared
,
const
Argument
&
karg
)
{
const
FloatAB
*
p_a_grid
=
karg
.
p_a_grid
;
const
FloatAB
*
p_b_grid
=
karg
.
p_b_grid
;
FloatC
*
p_c_grid
=
karg
.
p_c_grid
;
#define CREATE_DESCS_ON_HOST 1
#if CREATE_DESCS_ON_HOST
const
auto
a_grid_desc_ak0_m_ak1
=
karg
.
a_grid_desc_ak0_m_ak1
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment