Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
dfbb659a
Commit
dfbb659a
authored
Jul 26, 2022
by
Chao Liu
Browse files
upate contraction example
parent
809799bf
Changes
3
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
128 additions
and
334 deletions
+128
-334
include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp
...or_operation/gpu/device/device_contraction_multiple_d.hpp
+8
-8
include/ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp
...gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp
+119
-323
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp
...ration/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp
+1
-3
No files found.
include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp
View file @
dfbb659a
...
...
@@ -43,14 +43,14 @@ struct DeviceContractionMultipleD : public BaseOperator
const
void
*
p_b
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_e
,
std
::
vector
<
index_t
>
a_ms_
k
s_lengths
,
std
::
vector
<
index_t
>
a_ms_ks_strides
,
std
::
vector
<
index_t
>
b_ns_ks_lengths
,
std
::
vector
<
index_t
>
b_ns_ks_strides
,
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>
ds_ms_ns_lengths
,
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>
ds_ms_ns_strides
,
std
::
vector
<
index_t
>
e_ms_ns_lengths
,
std
::
vector
<
index_t
>
e_ms_ns_strides
,
const
std
::
vector
<
index_t
>
&
a_ms_
n
s_lengths
,
const
std
::
vector
<
index_t
>
&
a_ms_ks_strides
,
const
std
::
vector
<
index_t
>
&
b_ns_ks_lengths
,
const
std
::
vector
<
index_t
>
&
b_ns_ks_strides
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>
&
ds_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>
&
ds_ms_ns_strides
,
const
std
::
vector
<
index_t
>
&
e_ms_ns_lengths
,
const
std
::
vector
<
index_t
>
&
e_ms_ns_strides
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
=
0
;
...
...
include/ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp
View file @
dfbb659a
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp
View file @
dfbb659a
...
...
@@ -434,9 +434,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
CDEElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
ck
::
StaticallyIndexedArray
<
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
NumDTensor
>
,
typename
GridwiseGemm
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
DefaultBlock2ETileMap
,
has_main_loop
>
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment