gaoqiong / composable_kernel · Commit 64b9b6a0
Authored May 17, 2023 by Po-Yen, Chen
Separate 'Problem' concept out from 'Argument'
Parent: 468ffbd6
Showing 3 changed files, with 158 additions and 156 deletions:
- include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp (+66, -64)
- include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp (+14, -51)
- include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp (+78, -41)
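Taken together, the change splits the former all-in-one 'Argument' into two pieces: a Problem that describes only the GEMM shape (sizes, strides and derived block counts) and an Argument that derives from it and adds the grid pointers, so that validity checks and descriptor setup can work on the shape alone. A minimal sketch of the pattern is below; the field and type names mirror the diff, but everything else (the index_t alias, the float element type, the toy validity check) is a simplified stand-in, not the CK implementation.

    // Sketch of the Problem/Argument split introduced by this commit.
    // Field names follow the diff; index_t, the element type and the
    // validity check are simplified stand-ins, not the CK definitions.
    #include <cstdint>
    #include <iostream>

    using index_t = std::int32_t;

    // 'Problem' carries only the GEMM shape: sizes and strides. It is what the
    // gridwise code needs for validity checks and descriptor construction.
    struct Problem
    {
        index_t M, N, K;
        index_t StrideA, StrideB, StrideC;

        void Print() const
        {
            std::cout << "problem {M:" << M << ", N:" << N << ", K:" << K << "}\n";
        }
    };

    // 'Argument' layers the grid pointers on top of the Problem. In CK it also
    // derives from tensor_operation::device::BaseArgument.
    struct Argument : public Problem
    {
        Argument(const float* p_a, const float* p_b, float* p_c,
                 index_t M_, index_t N_, index_t K_,
                 index_t StrideA_, index_t StrideB_, index_t StrideC_)
            : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
              p_a_grid{p_a}, p_b_grid{p_b}, p_c_grid{p_c}
        {
        }

        const float* p_a_grid;
        const float* p_b_grid;
        float* p_c_grid;
    };

    // Host-side checks only need the shape, so they take a Problem...
    bool CheckValidity(const Problem& problem) { return problem.K % 8 == 0; }

    int main()
    {
        float a[32]{}, b[32]{}, c[16]{};
        Argument arg{a, b, c, 4, 4, 8, 8, 8, 4};

        arg.Print();                           // inherited from Problem
        std::cout << CheckValidity(arg) << "\n"; // Argument binds to const Problem&
    }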
include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp

@@ -168,9 +168,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
     using CGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));

     // Argument
-    struct Argument : public GridwiseGemm::Argument
+    struct Argument : public tensor_operation::device::BaseArgument, public GridwiseGemm::Problem
     {
-        using Parent = typename GridwiseGemm::Argument;
+        using Problem = typename GridwiseGemm::Problem;

         Argument(const ADataType* p_a_grid_real_,
                  const ADataType* p_a_grid_imag_,
@@ -185,7 +185,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
                  index_t StrideA_,
                  index_t StrideB_,
                  index_t StrideC_)
-            : Parent(M_, N_, K_, StrideA_, StrideB_, StrideC_),
+            : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
               p_a_grid_real{p_a_grid_real_},
               p_a_grid_imag{p_a_grid_imag_},
               p_b_grid_real{p_b_grid_real_},
@@ -225,22 +225,22 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
     // Invoker
     struct Invoker : public BaseInvoker
     {
-        float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
+        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
             if(stream_config.log_level_ > 0)
             {
-                karg.Print();
+                arg.Print();
             }

-            if(!GridwiseGemm::CheckValidity(karg))
+            if(!GridwiseGemm::CheckValidity(arg))
             {
                 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
             }

             index_t gdx, gdy, gdz;
-            std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(karg.M, karg.N);
+            std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);

-            const auto K = GridwiseGemm::CalculateAK0(karg.K) * AK1;
+            const auto K = GridwiseGemm::CalculateAK0(arg.K) * AK1;

             float ave_time = 0;
@@ -284,27 +284,28 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
             if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
             {
-                const auto kernel = kernel_gemm_xdl_cshuffle_v1_simplified<GridwiseGemm, true>;
+                const auto kernel =
+                    kernel_gemm_xdl_cshuffle_v2<GridwiseGemm, ADataType, CDataType, true>;

                 ave_time += launch_and_time_kernel(stream_config,
                     kernel,
                     dim3(gdx, gdy, gdz),
                     dim3(BlockSize),
                     0,
-                    karg.p_a_grid_real,
-                    karg.p_b_grid_real,
-                    karg.p_aux_grid,
-                    karg);
+                    arg.p_a_grid_real,
+                    arg.p_b_grid_real,
+                    arg.p_aux_grid,
+                    arg);

                 ave_time += launch_and_time_kernel(stream_config,
                     kernel,
                     dim3(gdx, gdy, gdz),
                     dim3(BlockSize),
                     0,
-                    karg.p_a_grid_imag,
-                    karg.p_b_grid_imag,
-                    karg.p_aux_2_grid,
-                    karg);
+                    arg.p_a_grid_imag,
+                    arg.p_b_grid_imag,
+                    arg.p_aux_2_grid,
+                    arg);

                 // c_real = aux - aux_2
                 ave_time += launch_and_time_kernel(
@@ -313,11 +314,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
                     dim3(gdx, gdy, gdz),
                     dim3(BlockSize),
                     0,
-                    make_tuple(karg.c_grid_desc_m, karg.c_grid_desc_m),
-                    make_tuple(karg.c_grid_desc_m),
-                    make_tuple(const_cast<const CDataType*>(karg.p_aux_grid),
-                               const_cast<const CDataType*>(karg.p_aux_2_grid)),
-                    make_tuple(karg.p_c_grid_real),
+                    make_tuple(arg.c_grid_desc_m, arg.c_grid_desc_m),
+                    make_tuple(arg.c_grid_desc_m),
+                    make_tuple(const_cast<const CDataType*>(arg.p_aux_grid),
+                               const_cast<const CDataType*>(arg.p_aux_2_grid)),
+                    make_tuple(arg.p_c_grid_real),
                     Subtract{});

                 ave_time += launch_and_time_kernel(stream_config,
@@ -325,20 +326,20 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
                     dim3(gdx, gdy, gdz),
                     dim3(BlockSize),
                     0,
-                    karg.p_a_grid_real,
-                    karg.p_b_grid_imag,
-                    karg.p_aux_grid,
-                    karg);
+                    arg.p_a_grid_real,
+                    arg.p_b_grid_imag,
+                    arg.p_aux_grid,
+                    arg);

                 ave_time += launch_and_time_kernel(stream_config,
                     kernel,
                     dim3(gdx, gdy, gdz),
                     dim3(BlockSize),
                     0,
-                    karg.p_a_grid_imag,
-                    karg.p_b_grid_real,
-                    karg.p_aux_2_grid,
-                    karg);
+                    arg.p_a_grid_imag,
+                    arg.p_b_grid_real,
+                    arg.p_aux_2_grid,
+                    arg);

                 // c_imag = aux + aux_2
                 ave_time += launch_and_time_kernel(
@@ -347,36 +348,37 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
                     dim3(gdx, gdy, gdz),
                     dim3(BlockSize),
                     0,
-                    make_tuple(karg.c_grid_desc_m, karg.c_grid_desc_m),
-                    make_tuple(karg.c_grid_desc_m),
-                    make_tuple(const_cast<const CDataType*>(karg.p_aux_grid),
-                               const_cast<const CDataType*>(karg.p_aux_2_grid)),
-                    make_tuple(karg.p_c_grid_imag),
+                    make_tuple(arg.c_grid_desc_m, arg.c_grid_desc_m),
+                    make_tuple(arg.c_grid_desc_m),
+                    make_tuple(const_cast<const CDataType*>(arg.p_aux_grid),
+                               const_cast<const CDataType*>(arg.p_aux_2_grid)),
+                    make_tuple(arg.p_c_grid_imag),
                     Add{});
             }
             else
             {
-                const auto kernel = kernel_gemm_xdl_cshuffle_v1_simplified<GridwiseGemm, false>;
+                const auto kernel =
+                    kernel_gemm_xdl_cshuffle_v2<GridwiseGemm, ADataType, CDataType, false>;

                 ave_time += launch_and_time_kernel(stream_config,
                     kernel,
                     dim3(gdx, gdy, gdz),
                     dim3(BlockSize),
                     0,
-                    karg.p_a_grid_real,
-                    karg.p_b_grid_real,
-                    karg.p_aux_grid,
-                    karg);
+                    arg.p_a_grid_real,
+                    arg.p_b_grid_real,
+                    arg.p_aux_grid,
+                    arg);

                 ave_time += launch_and_time_kernel(stream_config,
                     kernel,
                     dim3(gdx, gdy, gdz),
                     dim3(BlockSize),
                     0,
-                    karg.p_a_grid_imag,
-                    karg.p_b_grid_imag,
-                    karg.p_aux_2_grid,
-                    karg);
+                    arg.p_a_grid_imag,
+                    arg.p_b_grid_imag,
+                    arg.p_aux_2_grid,
+                    arg);

                 // c_real = aux - aux_2
                 ave_time += launch_and_time_kernel(
@@ -385,11 +387,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
                     dim3(gdx, gdy, gdz),
                     dim3(BlockSize),
                     0,
-                    make_tuple(karg.c_grid_desc_m, karg.c_grid_desc_m),
-                    make_tuple(karg.c_grid_desc_m),
-                    make_tuple(const_cast<const CDataType*>(karg.p_aux_grid),
-                               const_cast<const CDataType*>(karg.p_aux_2_grid)),
-                    make_tuple(karg.p_c_grid_real),
+                    make_tuple(arg.c_grid_desc_m, arg.c_grid_desc_m),
+                    make_tuple(arg.c_grid_desc_m),
+                    make_tuple(const_cast<const CDataType*>(arg.p_aux_grid),
+                               const_cast<const CDataType*>(arg.p_aux_2_grid)),
+                    make_tuple(arg.p_c_grid_real),
                     Subtract{});

                 ave_time += launch_and_time_kernel(stream_config,
@@ -397,20 +399,20 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
                     dim3(gdx, gdy, gdz),
                     dim3(BlockSize),
                     0,
-                    karg.p_a_grid_real,
-                    karg.p_b_grid_imag,
-                    karg.p_aux_grid,
-                    karg);
+                    arg.p_a_grid_real,
+                    arg.p_b_grid_imag,
+                    arg.p_aux_grid,
+                    arg);

                 ave_time += launch_and_time_kernel(stream_config,
                     kernel,
                     dim3(gdx, gdy, gdz),
                     dim3(BlockSize),
                     0,
-                    karg.p_a_grid_imag,
-                    karg.p_b_grid_real,
-                    karg.p_aux_2_grid,
-                    karg);
+                    arg.p_a_grid_imag,
+                    arg.p_b_grid_real,
+                    arg.p_aux_2_grid,
+                    arg);

                 // c_imag = aux + aux_2
                 ave_time += launch_and_time_kernel(
@@ -419,11 +421,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
                     dim3(gdx, gdy, gdz),
                     dim3(BlockSize),
                     0,
-                    make_tuple(karg.c_grid_desc_m, karg.c_grid_desc_m),
-                    make_tuple(karg.c_grid_desc_m),
-                    make_tuple(const_cast<const CDataType*>(karg.p_aux_grid),
-                               const_cast<const CDataType*>(karg.p_aux_2_grid)),
-                    make_tuple(karg.p_c_grid_imag),
+                    make_tuple(arg.c_grid_desc_m, arg.c_grid_desc_m),
+                    make_tuple(arg.c_grid_desc_m),
+                    make_tuple(const_cast<const CDataType*>(arg.p_aux_grid),
+                               const_cast<const CDataType*>(arg.p_aux_2_grid)),
+                    make_tuple(arg.p_c_grid_imag),
                     Add{});
             }
@@ -444,9 +446,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
         return true;
     }

-    static bool IsSupportedArgument(const Argument& karg)
+    static bool IsSupportedArgument(const Argument& arg)
     {
-        return GridwiseGemm::CheckValidity(karg);
+        return GridwiseGemm::CheckValidity(arg);
     }

     // polymorphic
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp

@@ -130,80 +130,43 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
        LoopSched,
        PipelineVer>;

-    struct Argument : public GridwiseGemm::Argument
-    {
-        using Parent = typename GridwiseGemm::Argument;
-
-        Argument(const ADataType* p_a_grid_,
-                 const BDataType* p_b_grid_,
-                 CDataType* p_c_grid_,
-                 index_t M_,
-                 index_t N_,
-                 index_t K_,
-                 index_t StrideA_,
-                 index_t StrideB_,
-                 index_t StrideC_)
-            : Parent(M_, N_, K_, StrideA_, StrideB_, StrideC_),
-              p_a_grid{p_a_grid_},
-              p_b_grid{p_b_grid_},
-              p_c_grid{p_c_grid_}
-        {
-        }
-
-        const ADataType* p_a_grid;
-        const BDataType* p_b_grid;
-        CDataType* p_c_grid;
-    };
+    using Argument = typename GridwiseGemm::Argument;

     // Invoker
     struct Invoker : public BaseInvoker
     {
-        float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
+        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
             if(stream_config.log_level_ > 0)
             {
-                karg.Print();
+                arg.Print();
             }

-            if(!GridwiseGemm::CheckValidity(karg))
+            if(!GridwiseGemm::CheckValidity(arg))
             {
                 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
             }

             index_t gdx, gdy, gdz;
-            std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(karg.M, karg.N);
+            std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);

-            const auto K = GridwiseGemm::CalculateAK0(karg.K) * AK1;
+            const auto K = GridwiseGemm::CalculateAK0(arg.K) * AK1;

             float ave_time = 0;

             if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
             {
-                const auto kernel = kernel_gemm_xdl_cshuffle_v1_simplified<GridwiseGemm, true>;
+                const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm, true>;

-                ave_time = launch_and_time_kernel(stream_config,
-                    kernel,
-                    dim3(gdx, gdy, gdz),
-                    dim3(BlockSize),
-                    0,
-                    karg.p_a_grid,
-                    karg.p_b_grid,
-                    karg.p_c_grid,
-                    karg);
+                ave_time = launch_and_time_kernel(
+                    stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
             }
             else
             {
-                const auto kernel = kernel_gemm_xdl_cshuffle_v1_simplified<GridwiseGemm, false>;
+                const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm, false>;

-                ave_time = launch_and_time_kernel(stream_config,
-                    kernel,
-                    dim3(gdx, gdy, gdz),
-                    dim3(BlockSize),
-                    0,
-                    karg.p_a_grid,
-                    karg.p_b_grid,
-                    karg.p_c_grid,
-                    karg);
+                ave_time = launch_and_time_kernel(
+                    stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
             }

             return ave_time;
@@ -223,9 +186,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
         return true;
     }

-    static bool IsSupportedArgument(const Argument& karg)
+    static bool IsSupportedArgument(const Argument& arg)
     {
-        return GridwiseGemm::CheckValidity(karg);
+        return GridwiseGemm::CheckValidity(arg);
     }

     // polymorphic
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp

@@ -22,32 +22,49 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_gemm_xdl_cshuffle_v1_simplified(
-            const typename GridwiseGemm::FloatAB* __restrict__ p_a_grid,
-            const typename GridwiseGemm::FloatAB* __restrict__ p_b_grid,
-            typename GridwiseGemm::FloatC* __restrict__ p_c_grid,
-            typename GridwiseGemm::Argument karg)
+        kernel_gemm_xdl_cshuffle_v1(typename GridwiseGemm::Argument karg)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
     defined(__gfx940__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

-    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, p_b_grid, p_c_grid, p_shared, karg);
+    GridwiseGemm::template Run<HasMainKBlockLoop>(
+        karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg);
 #else
-    ignore = p_a_grid;
-    ignore = p_b_grid;
-    ignore = p_c_grid;
     ignore = karg;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }

+template <typename GridwiseGemm,
+          typename FloatAB,
+          typename FloatC,
+          bool HasMainKBlockLoop>
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+        kernel_gemm_xdl_cshuffle_v2(const FloatAB* __restrict__ p_a_grid,
+                                    const FloatAB* __restrict__ p_b_grid,
+                                    FloatC* __restrict__ p_c_grid,
+                                    typename GridwiseGemm::Problem problem)
+{
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__))
+    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
+
+    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, p_b_grid, p_c_grid, p_shared, problem);
+#else
+    ignore = p_a_grid;
+    ignore = p_b_grid;
+    ignore = p_c_grid;
+    ignore = problem;
+#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
+}
+
 template <typename ALayout,
           typename BLayout,
           typename CLayout,
-          typename FloatAB_,
+          typename FloatAB,
           typename FloatGemmAcc,
           typename FloatCShuffle,
-          typename FloatC_,
+          typename FloatC,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
@@ -103,9 +120,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
     static constexpr auto AK1Number = Number<AK1Value>{};
     static constexpr auto BK1Number = Number<BK1Value>{};

-    using FloatAB = FloatAB_;
-    using FloatC = FloatC_;
-
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;

     __host__ static auto CalculateGridSize(index_t M, index_t N)
@@ -389,15 +403,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         }
     }

-    // Argument
-    struct Argument : public tensor_operation::device::BaseArgument
+    struct Problem
     {
-        __host__ Argument(index_t M_,
+        __host__ Problem(index_t M_,
                          index_t N_,
                          index_t K_,
                          index_t StrideA_,
                          index_t StrideB_,
                          index_t StrideC_)
             : M{M_},
               N{N_},
               K{K_},
@@ -416,7 +429,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         __host__ void Print() const
         {
-            std::cout << "arg {"
+            std::cout << "problem {"
                       << "M:" << M << ", "
                       << "N:" << N << ", "
                       << "K:" << K << ", "
@@ -447,6 +460,30 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         index_t NBlock;
     };

+    // Argument
+    struct Argument : public tensor_operation::device::BaseArgument, public Problem
+    {
+        __host__ Argument(const FloatAB* p_a_grid_,
+                          const FloatAB* p_b_grid_,
+                          FloatC* p_c_grid_,
+                          index_t M_,
+                          index_t N_,
+                          index_t K_,
+                          index_t StrideA_,
+                          index_t StrideB_,
+                          index_t StrideC_)
+            : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
+              p_a_grid{p_a_grid_},
+              p_b_grid{p_b_grid_},
+              p_c_grid{p_c_grid_}
+        {
+        }
+
+        const FloatAB* p_a_grid;
+        const FloatAB* p_b_grid;
+        FloatC* p_c_grid;
+    };
+
     // FIXME: pass GridwiseGemmPipe as a template arguement into GridwiseGemm
     using GridwiseGemmPipe = remove_cvref_t<decltype(
         GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
@@ -510,7 +547,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
     }

     // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
-    __host__ static constexpr bool CheckValidity(const Argument& karg)
+    __host__ static constexpr bool CheckValidity(const Problem& problem)
     {
         static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                           (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
@@ -521,7 +558,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                     GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
                     GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
         {
-            if(!(karg.M % MPerBlock == 0))
+            if(!(problem.M % MPerBlock == 0))
             {
                 return false;
             }
@@ -532,7 +569,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                     GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
                     GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
         {
-            if(!(karg.N % NPerBlock == 0))
+            if(!(problem.N % NPerBlock == 0))
             {
                 return false;
             }
@@ -543,15 +580,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                     GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
                     GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding)
         {
-            if(!(CalculateKPadded(karg.K) % AK1Value == 0) ||
-               !(CalculateKPadded(karg.K) % BK1Value == 0))
+            if(!(CalculateKPadded(problem.K) % AK1Value == 0) ||
+               !(CalculateKPadded(problem.K) % BK1Value == 0))
             {
                 return false;
             }
         }
         else
         {
-            if(!(karg.K % AK1Value == 0) || !(karg.K % BK1Value == 0))
+            if(!(problem.K % AK1Value == 0) || !(problem.K % BK1Value == 0))
             {
                 return false;
             }
@@ -559,14 +596,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
         {
-            if(karg.K % ABlockTransferSrcScalarPerVector != 0)
+            if(problem.K % ABlockTransferSrcScalarPerVector != 0)
             {
                 return false;
             }
         }
         else
         {
-            if(karg.M % ABlockTransferSrcScalarPerVector != 0)
+            if(problem.M % ABlockTransferSrcScalarPerVector != 0)
             {
                 return false;
             }
@@ -574,14 +611,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
         {
-            if(karg.N % BBlockTransferSrcScalarPerVector != 0)
+            if(problem.N % BBlockTransferSrcScalarPerVector != 0)
             {
                 return false;
             }
         }
         else
         {
-            if(karg.K % BBlockTransferSrcScalarPerVector != 0)
+            if(problem.K % BBlockTransferSrcScalarPerVector != 0)
             {
                 return false;
             }
@@ -589,21 +626,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
         {
-            if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
+            if(problem.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
             {
                 return false;
             }
         }
         else
         {
-            if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
+            if(problem.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
             {
                 return false;
             }
         }

         // check gridwise gemm pipeline
-        const auto num_k_loop = (CalculateAK0(karg.K) * AK1Value) / KPerBlock;
+        const auto num_k_loop = (CalculateAK0(problem.K) * AK1Value) / KPerBlock;

         if(!GridwiseGemmPipe::IsSupported(num_k_loop))
         {
@@ -643,18 +680,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                                const FloatAB* __restrict__ p_b_grid,
                                FloatC* __restrict__ p_c_grid,
                                void* __restrict__ p_shared,
-                               const Argument& karg)
+                               const Problem& problem)
     {
         const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
-            karg.M, karg.MPadded, karg.K, karg.KPadded, karg.StrideA, karg.AK0);
+            problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
         const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
-            karg.K, karg.KPadded, karg.N, karg.NPadded, karg.StrideB, karg.BK0);
+            problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
         const auto c_grid_desc_m_n =
-            MakeCGridDescriptor_M_N(karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC);
+            MakeCGridDescriptor_M_N(problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
         const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
             MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
-                c_grid_desc_m_n, karg.MBlock, karg.NBlock);
+                c_grid_desc_m_n, problem.MBlock, problem.NBlock);

         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
@@ -668,7 +705,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         const CElementwiseOperation c_element_op{};

         // divide block work by [M, N]
-        const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N};
+        const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N};

         const auto block_work_idx =
             block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
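One consequence of the split shows up in the updated device_cgemm invoker: its Argument derives from GridwiseGemm::Problem, so passing arg where kernel_gemm_xdl_cshuffle_v2 expects a Problem by value copies only the Problem base subobject (sizes and strides), while the grid pointers are forwarded as separate kernel arguments. A small standalone sketch of that mechanism, in plain host C++ with hypothetical names rather than the CK kernels:

    // Demonstrates base-subobject copying ("slicing") when an Argument is passed
    // where a Problem is taken by value, mirroring how the device code forwards
    // pointers and shape to kernel_gemm_xdl_cshuffle_v2. Hypothetical stand-in code.
    #include <iostream>

    struct Problem
    {
        int M, N, K;
    };

    struct Argument : Problem
    {
        const float* p_a;
        const float* p_b;
        float* p_c;
    };

    // Stand-in for the kernel launch: pointers travel separately, the shape-only
    // Problem is copied by value, and no pointers are carried along with it.
    void fake_launch(const float* a, const float* b, float* c, Problem problem)
    {
        std::cout << "launch: M=" << problem.M << " N=" << problem.N
                  << " K=" << problem.K << "\n";
        (void)a;
        (void)b;
        (void)c;
    }

    int main()
    {
        float a[8]{}, b[8]{}, c[4]{};

        Argument arg{};
        arg.M   = 2;
        arg.N   = 2;
        arg.K   = 4;
        arg.p_a = a;
        arg.p_b = b;
        arg.p_c = c;

        // Passing 'arg' as the last parameter copies only its Problem base.
        fake_launch(arg.p_a, arg.p_b, arg.p_c, arg);
    }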