Commit 83cb77b3, authored Dec 04, 2023 by Bartlomiej Wroblewski

Use utils in the multi ABD contraction

parent c126681c
Showing 2 changed files with 90 additions and 88 deletions:

include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp (+69, -66)
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp (+21, -22)
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp
@@ -14,6 +14,7 @@
 #include "ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
@@ -500,22 +501,29 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
             // for sanity check of vector memory access
             for(index_t i = 0; i < NumATensor; ++i)
             {
-                a_mz_stride_[i] = a_ms_ks_strides[i][NumDimM - 1];
-                a_kz_stride_[i] = a_ms_ks_strides[i][NumDimM + NumDimK - 1];
+                as_mz_consecutive_[i] = a_ms_ks_strides[i][NumDimM - 1] == 1;
+                as_kz_consecutive_[i] = a_ms_ks_strides[i][NumDimM + NumDimK - 1] == 1;
+                as_max_read_elems_[i] =
+                    CalculateMaxRead<NumDimM, NumDimK>(a_ms_ks_lengths[i], a_ms_ks_strides[i]);
             }
 
             for(index_t i = 0; i < NumBTensor; ++i)
             {
-                b_nz_stride_[i] = b_ns_ks_strides[i][NumDimN - 1];
-                b_kz_stride_[i] = b_ns_ks_strides[i][NumDimN + NumDimK - 1];
+                bs_nz_consecutive_[i] = b_ns_ks_strides[i][NumDimN - 1] == 1;
+                bs_kz_consecutive_[i] = b_ns_ks_strides[i][NumDimN + NumDimK - 1] == 1;
+                bs_max_read_elems_[i] =
+                    CalculateMaxRead<NumDimN, NumDimK>(b_ns_ks_lengths[i], b_ns_ks_strides[i]);
             }
 
             for(index_t i = 0; i < NumDTensor; ++i)
             {
-                ds_nz_stride_[i] = d_ms_ns_strides[i][NumDimM + NumDimN - 1];
+                ds_nz_consecutive_[i] = d_ms_ns_strides[i][NumDimM + NumDimN - 1] == 1;
+                ds_max_read_elems_[i] =
+                    CalculateMaxRead<NumDimM, NumDimN>(d_ms_ns_lengths[i], d_ms_ns_strides[i]);
             }
 
-            e_nz_stride_ = e_ms_ns_stride[NumDimM + NumDimN - 1];
+            e_nz_consecutive_ = e_ms_ns_stride[NumDimM + NumDimN - 1] == 1;
+            e_max_write_elems_ =
+                CalculateMaxRead<NumDimM, NumDimN>(e_ms_ns_length, e_ms_ns_stride);
         }
 
         // pointers
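For intuition, here is a small worked example (dimension counts and layout are hypothetical, chosen only for illustration) showing how the consecutive flags computed above behave for a single A tensor:

```cpp
// Worked example: the same index arithmetic as the hunk above,
// for one A tensor with NumDimM = 2, NumDimK = 2.
#include <array>
#include <cassert>

int main()
{
    constexpr int NumDimM = 2;
    constexpr int NumDimK = 2;

    // A[M0=8, M1=4, K0=16, K1=2], packed row-major.
    const std::array<int, 4> a_ms_ks_strides{128, 32, 2, 1};

    const bool a_mz_consecutive = a_ms_ks_strides[NumDimM - 1] == 1;           // stride of M1
    const bool a_kz_consecutive = a_ms_ks_strides[NumDimM + NumDimK - 1] == 1; // stride of K1

    assert(!a_mz_consecutive); // M1 has stride 32: M is not the fastest-varying dimension
    assert(a_kz_consecutive);  // K1 has stride 1: vectorized reads along K are possible
    (void)a_mz_consecutive;
    (void)a_kz_consecutive;
    return 0;
}
```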
@@ -545,16 +553,19 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
         BElementwiseOperation b_element_op_;
         CDEElementwiseOperation cde_element_op_;
 
-        // Strides for the last M/N/K dimensions of A/B/Ds/E
-        // for sanity check of vector load/store
-        std::array<index_t, NumATensor> a_mz_stride_;
-        std::array<index_t, NumATensor> a_kz_stride_;
-        std::array<index_t, NumBTensor> b_nz_stride_;
-        std::array<index_t, NumBTensor> b_kz_stride_;
-        std::array<index_t, NumDTensor> ds_nz_stride_;
-        index_t e_nz_stride_;
+        // Describe whether the last part of a given dimension of A/B/D/E is consecutive
+        // in the memory or not.
+        std::array<bool, NumATensor> as_mz_consecutive_;
+        std::array<bool, NumATensor> as_kz_consecutive_;
+        std::array<bool, NumBTensor> bs_nz_consecutive_;
+        std::array<bool, NumBTensor> bs_kz_consecutive_;
+        std::array<bool, NumDTensor> ds_nz_consecutive_;
+        bool e_nz_consecutive_;
+
+        std::array<index_t, NumATensor> as_max_read_elems_;
+        std::array<index_t, NumBTensor> bs_max_read_elems_;
+        std::array<index_t, NumDTensor> ds_max_read_elems_;
+        index_t e_max_write_elems_;
     };
 
     // Invoker
@@ -643,73 +654,65 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
         // check vector load/store
         {
-            bool all_valid = true;
-
+            bool valid_as_access = true;
             static_for<0, NumATensor, 1>{}([&](auto i) {
+                const bool valid_a_vector_size =
+                    arg.as_max_read_elems_[i] % ABlockTransferSrcScalarPerVector == 0;
                 // vector memory access of A: could be on M or AK1 dimension
-                if constexpr(ABlockTransferSrcVectorDim == 1)
-                {
-                    if(!(arg.a_mz_stride_[i] == 1 &&
-                         arg.as_grid_desc_ak0_m_ak1_[i].GetLength(I1) %
-                                 ABlockTransferSrcScalarPerVector ==
-                             0))
-                    {
-                        all_valid = false;
-                    }
-                }
-                else
-                {
-                    if(!(arg.a_kz_stride_[i] == 1 &&
-                         arg.as_grid_desc_ak0_m_ak1_[i].GetLength(I2) %
-                                 ABlockTransferSrcScalarPerVector ==
-                             0))
-                    {
-                        all_valid = false;
-                    }
-                }
+                const bool valid_a_access_dim_m =
+                    ABlockTransferSrcVectorDim == 1 && arg.as_mz_consecutive_[i];
+                const bool valid_a_access_dim_k =
+                    ABlockTransferSrcVectorDim == 2 && arg.as_kz_consecutive_[i];
+                const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k;
+
+                if(!(valid_a_vector_size && valid_a_access_dim))
+                {
+                    valid_as_access = false;
+                }
             });
+            if(!valid_as_access)
+            {
+                return false;
+            }
 
             // vector memory access of B: could be on N or BK1 dimension
+            bool valid_bs_access = true;
             static_for<0, NumBTensor, 1>{}([&](auto i) {
-                if constexpr(BBlockTransferSrcVectorDim == 1)
-                {
-                    if(!(arg.b_nz_stride_[i] == 1 &&
-                         arg.bs_grid_desc_bk0_n_bk1_[i].GetLength(I1) %
-                                 BBlockTransferSrcScalarPerVector ==
-                             0))
-                    {
-                        all_valid = false;
-                    }
-                }
-                else
-                {
-                    if(!(arg.b_kz_stride_[i] == 1 &&
-                         arg.bs_grid_desc_bk0_n_bk1_[i].GetLength(I2) %
-                                 BBlockTransferSrcScalarPerVector ==
-                             0))
-                    {
-                        all_valid = false;
-                    }
-                }
+                const bool valid_b_vector_size =
+                    arg.bs_max_read_elems_[i] % BBlockTransferSrcScalarPerVector == 0;
+                const bool valid_b_access_dim_n =
+                    BBlockTransferSrcVectorDim == 1 && arg.bs_nz_consecutive_[i];
+                const bool valid_b_access_dim_k =
+                    BBlockTransferSrcVectorDim == 2 && arg.bs_kz_consecutive_[i];
+                const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k;
+
+                if(!(valid_b_vector_size && valid_b_access_dim))
+                {
+                    valid_bs_access = false;
+                }
             });
+            if(!valid_bs_access)
+            {
+                return false;
+            }
 
-            // check vector load of Ds
+            bool valid_ds_access = true;
             static_for<0, NumDTensor, 1>{}([&](auto i) {
-                if(!(arg.ds_nz_stride_[i] == 1 &&
-                     arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[i].GetLength(I3) %
-                             CDEBlockTransferScalarPerVector_NPerBlock ==
-                         0))
-                {
-                    all_valid = false;
-                }
+                const bool valid_d_vector_size =
+                    arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0;
+                // Vector read of Ds is always on N dimension.
+                const bool valid_d_access_dim = arg.ds_nz_consecutive_[i];
+
+                if(!(valid_d_vector_size && valid_d_access_dim))
+                {
+                    valid_ds_access = false;
+                }
             });
+            if(!valid_ds_access)
+            {
+                return false;
+            }
 
-            // vector memory access of E: always on NPerBlock dimension
-            if(!(arg.e_nz_stride_ == 1 &&
-                 arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) %
-                         CDEBlockTransferScalarPerVector_NPerBlock ==
-                     0))
-            {
-                all_valid = false;
-            }
-
-            if(!all_valid)
-            {
-                return false;
-            }
+            const bool valid_e_vector_size =
+                arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0;
+            // Vector write of E is always on N dimension.
+            const bool valid_e_access_dim = arg.e_nz_consecutive_;
+
+            if(!(valid_e_vector_size && valid_e_access_dim))
+            {
+                return false;
+            }
         }
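The refactored check above splits each tensor's test into two independent predicates: the configured source vector dimension must be consecutive in memory, and the maximum contiguous read/write must be divisible by the scalars-per-vector. The following stand-alone sketch mirrors that structure; the function and parameter names are illustrative and not part of the device op's API.

```cpp
#include <cstdint>

using index_t = std::int32_t;

// Sketch of the per-tensor validity test used by the new IsSupportedArgument code:
// the vector dimension must map onto a unit-stride (consecutive) dimension, and the
// longest contiguous run must be a multiple of the vector width.
bool is_vector_access_valid(index_t max_contiguous_elems,
                            index_t scalar_per_vector,
                            int src_vector_dim, // 1 = M/N dimension, 2 = K dimension
                            bool last_mn_consecutive,
                            bool last_k_consecutive)
{
    const bool valid_vector_size = max_contiguous_elems % scalar_per_vector == 0;
    const bool valid_access_dim  = (src_vector_dim == 1 && last_mn_consecutive) ||
                                  (src_vector_dim == 2 && last_k_consecutive);
    return valid_vector_size && valid_access_dim;
}

int main()
{
    // E.g. scalar-per-vector 8, reading along K (dim 2), K contiguous,
    // contiguous run of 1024 elements: the argument would be supported.
    return is_vector_access_valid(1024, 8, 2, /*last_mn*/ false, /*last_k*/ true) ? 0 : 1;
}
```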
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
@@ -11,9 +11,9 @@
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
@@ -184,7 +184,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
             return generate_tuple([&](auto i) { return vec[i]; }, num);
         };
 
-        const auto a_ms_ns_lengths = to_tuple(a_ms_ks_lengths_vec, Number<NumDimM + NumDimK>{});
+        const auto a_ms_ks_lengths = to_tuple(a_ms_ks_lengths_vec, Number<NumDimM + NumDimK>{});
         const auto a_ms_ks_strides = to_tuple(a_ms_ks_strides_vec, Number<NumDimM + NumDimK>{});
 
         // dimension Ids for M0, M1, ...
@@ -195,14 +195,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
             typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimK, 1>::type{};
 
         // lengths for M0, M1, ...
-        const auto mLengths = get_container_subset(a_ms_ns_lengths, mDimIds);
+        const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds);
 
         // lengths for K0, K1, ...
-        const auto kLengths = get_container_subset(a_ms_ns_lengths, kDimIds);
+        const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds);
 
         // naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
         const auto a_grid_desc_ms_ks =
-            make_naive_tensor_descriptor(a_ms_ns_lengths, a_ms_ks_strides);
+            make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides);
 
         // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
         const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor(
@@ -384,7 +384,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
                  const void* p_b_grid,
                  std::array<const void*, NumDTensor> p_ds_grid,
                  void* p_e_grid,
-                 const std::vector<index_t>& a_ms_ns_lengths,
+                 const std::vector<index_t>& a_ms_ks_lengths,
                  const std::vector<index_t>& a_ms_ks_strides,
                  const std::vector<index_t>& b_ns_ks_lengths,
                  const std::vector<index_t>& b_ns_ks_strides,
@@ -399,7 +399,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
               p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
               p_ds_grid_{},
               p_e_grid_{static_cast<EDataType*>(p_e_grid)},
-              a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_ms_ns_lengths, a_ms_ks_strides)},
+              a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_ms_ks_lengths, a_ms_ks_strides)},
               b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(b_ns_ks_lengths, b_ns_ks_strides)},
               ds_grid_desc_m_n_{},
               e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_ms_ns_lengths, e_ms_ns_strides)},
@@ -445,25 +445,24 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
             // for sanity check of vector memory access
             a_mz_consecutive_ = a_ms_ks_strides[NumDimM - 1] == 1;
             a_kz_consecutive_ = a_ms_ks_strides[NumDimM + NumDimK - 1] == 1;
+            a_max_read_elems_ =
+                CalculateMaxRead<NumDimM, NumDimK>(a_ms_ks_lengths, a_ms_ks_strides);
 
             b_nz_consecutive_ = b_ns_ks_strides[NumDimN - 1] == 1;
             b_kz_consecutive_ = b_ns_ks_strides[NumDimN + NumDimK - 1] == 1;
-
-            for(index_t i = 0; i < NumDTensor; ++i)
-            {
-                ds_nz_consecutive_[i] = ds_ms_ns_strides[i][NumDimM + NumDimN - 1] == 1;
-            }
-
-            e_nz_consecutive_ = e_ms_ns_strides[NumDimM + NumDimN - 1] == 1;
-
-            a_max_read_elems_ =
-                CalculateMaxRead<NumDimM, NumDimK>(a_ms_ns_lengths, a_ms_ks_strides);
             b_max_read_elems_ =
                 CalculateMaxRead<NumDimN, NumDimK>(b_ns_ks_lengths, b_ns_ks_strides);
 
             for(index_t i = 0; i < NumDTensor; ++i)
             {
+                ds_nz_consecutive_[i] = ds_ms_ns_strides[i][NumDimM + NumDimN - 1] == 1;
                 ds_max_read_elems_[i] =
-                    CalculateMaxRead<NumDimM, NumDimK>(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]);
+                    CalculateMaxRead<NumDimM, NumDimN>(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]);
             }
 
+            e_nz_consecutive_ = e_ms_ns_strides[NumDimM + NumDimN - 1] == 1;
             e_max_write_elems_ =
-                CalculateMaxRead<NumDimM, NumDimK>(e_ms_ns_lengths, e_ms_ns_strides);
+                CalculateMaxRead<NumDimM, NumDimN>(e_ms_ns_lengths, e_ms_ns_strides);
         }
 
         void Print() const
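The template-argument change from <NumDimM, NumDimK> to <NumDimM, NumDimN> for Ds and E in the hunk above is a correctness fix rather than a rename: D and E only have M and N dimensions, so a helper parameterized by the K rank is only coincidentally right when NumDimK equals NumDimN. A small illustration of the indexing involved (dimension counts and strides are hypothetical):

```cpp
#include <array>
#include <cstddef>

int main()
{
    constexpr std::size_t NumDimM = 2;
    constexpr std::size_t NumDimN = 1;
    constexpr std::size_t NumDimK = 3; // K rank of A/B; irrelevant to D and E

    // D[M0, M1, N0]: its strides array has NumDimM + NumDimN = 3 entries.
    constexpr std::array<int, NumDimM + NumDimN> d_ms_ns_strides{128, 32, 1};

    // The last N stride really sits at index NumDimM + NumDimN - 1 ...
    static_assert(d_ms_ns_strides[NumDimM + NumDimN - 1] == 1, "last N dim is packed");

    // ... while a helper instantiated as <NumDimM, NumDimK> (the old code) would
    // assume NumDimM + NumDimK = 5 trailing entries, which cannot match the
    // rank of D/E whenever NumDimK != NumDimN.
    static_assert(NumDimM + NumDimK != d_ms_ns_strides.size(),
                  "old template arguments mismatch the D/E rank when NumDimK != NumDimN");
    return 0;
}
```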
@@ -655,7 +654,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
                     valid_ds_access = false;
                 }
             });
-            if(valid_ds_access == false)
+            if(!valid_ds_access)
             {
                 return false;
             }
@@ -682,7 +681,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
                  const void* p_b,
                  std::array<const void*, NumDTensor> p_ds,
                  void* p_e,
-                 const std::vector<index_t>& a_ms_ns_lengths,
+                 const std::vector<index_t>& a_ms_ks_lengths,
                  const std::vector<index_t>& a_ms_ks_strides,
                  const std::vector<index_t>& b_ns_ks_lengths,
                  const std::vector<index_t>& b_ns_ks_strides,
@@ -698,7 +697,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
                         p_b,
                         p_ds,
                         p_e,
-                        a_ms_ns_lengths,
+                        a_ms_ks_lengths,
                         a_ms_ks_strides,
                         b_ns_ks_lengths,
                         b_ns_ks_strides,
@@ -719,7 +718,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
                        const void* p_b,
                        std::array<const void*, NumDTensor> p_ds,
                        void* p_e,
-                       const std::vector<index_t>& a_ms_ns_lengths,
+                       const std::vector<index_t>& a_ms_ks_lengths,
                        const std::vector<index_t>& a_ms_ks_strides,
                        const std::vector<index_t>& b_ns_ks_lengths,
                        const std::vector<index_t>& b_ns_ks_strides,
@@ -735,7 +734,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
                             p_b,
                             p_ds,
                             p_e,
-                            a_ms_ns_lengths,
+                            a_ms_ks_lengths,
                             a_ms_ks_strides,
                             b_ns_ks_lengths,
                             b_ns_ks_strides,