Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
933951ed
Unverified
Commit
933951ed
authored
Jun 18, 2024
by
Bartłomiej Kocot
Committed by
GitHub
Jun 18, 2024
Browse files
Fix continous dim selection in contraction (#1336)
* Fix continous dim selection in contraction * Fixes
parent
17ed368f
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
80 additions
and
59 deletions
+80
-59
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp
...ice/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp
+17
-25
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
...evice/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
+19
-24
include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp
...or_operation/gpu/device/impl/device_contraction_utils.hpp
+38
-10
test/contraction/test_contraction_xdl.cpp
test/contraction/test_contraction_xdl.cpp
+6
-0
No files found.
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp
View file @
933951ed
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023
-2024
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -501,29 +501,24 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
...
@@ -501,29 +501,24 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
// for sanity check of vector memory access
// for sanity check of vector memory access
for
(
index_t
i
=
0
;
i
<
NumATensor
;
++
i
)
for
(
index_t
i
=
0
;
i
<
NumATensor
;
++
i
)
{
{
as_mz_consecutive_
[
i
]
=
a_ms_ks_strides
[
i
][
NumDimM
-
1
]
==
1
;
tie
(
as_continous_dim_
[
i
],
as_max_read_elems_
[
i
])
=
as_kz_consecutive_
[
i
]
=
a_ms_ks_strides
[
i
][
NumDimM
+
NumDimK
-
1
]
==
1
;
as_max_read_elems_
[
i
]
=
CalculateMaxRead
<
NumDimM
,
NumDimK
>
(
a_ms_ks_lengths
[
i
],
a_ms_ks_strides
[
i
]);
CalculateMaxRead
<
NumDimM
,
NumDimK
>
(
a_ms_ks_lengths
[
i
],
a_ms_ks_strides
[
i
]);
}
}
for
(
index_t
i
=
0
;
i
<
NumBTensor
;
++
i
)
for
(
index_t
i
=
0
;
i
<
NumBTensor
;
++
i
)
{
{
bs_nz_consecutive_
[
i
]
=
b_ns_ks_strides
[
i
][
NumDimN
-
1
]
==
1
;
tie
(
bs_continous_dim_
[
i
],
bs_max_read_elems_
[
i
])
=
bs_kz_consecutive_
[
i
]
=
b_ns_ks_strides
[
i
][
NumDimN
+
NumDimK
-
1
]
==
1
;
bs_max_read_elems_
[
i
]
=
CalculateMaxRead
<
NumDimN
,
NumDimK
>
(
b_ns_ks_lengths
[
i
],
b_ns_ks_strides
[
i
]);
CalculateMaxRead
<
NumDimN
,
NumDimK
>
(
b_ns_ks_lengths
[
i
],
b_ns_ks_strides
[
i
]);
}
}
for
(
index_t
i
=
0
;
i
<
NumDTensor
;
++
i
)
for
(
index_t
i
=
0
;
i
<
NumDTensor
;
++
i
)
{
{
ds_nz_consecutive_
[
i
]
=
d_ms_ns_strides
[
i
][
NumDimM
+
NumDimN
-
1
]
==
1
;
tie
(
ds_continous_dim_
[
i
],
ds_max_read_elems_
[
i
])
=
ds_max_read_elems_
[
i
]
=
CalculateMaxRead
<
NumDimM
,
NumDimN
>
(
d_ms_ns_lengths
[
i
],
d_ms_ns_strides
[
i
]);
CalculateMaxRead
<
NumDimM
,
NumDimN
>
(
d_ms_ns_lengths
[
i
],
d_ms_ns_strides
[
i
]);
}
}
e_nz_consecutive_
=
e_ms_ns_stride
[
NumDimM
+
NumDimN
-
1
]
==
1
;
tie
(
e_continous_dim_
,
e_max_write_elems_
)
=
e_max_write_elems_
=
CalculateMaxRead
<
NumDimM
,
NumDimN
>
(
e_ms_ns_length
,
e_ms_ns_stride
);
CalculateMaxRead
<
NumDimM
,
NumDimN
>
(
e_ms_ns_length
,
e_ms_ns_stride
);
}
}
// pointers
// pointers
...
@@ -553,14 +548,11 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
...
@@ -553,14 +548,11 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
BElementwiseOperation
b_element_op_
;
BElementwiseOperation
b_element_op_
;
CDEElementwiseOperation
cde_element_op_
;
CDEElementwiseOperation
cde_element_op_
;
// Describe whether the last part of a given dimension of A/B/D/E is consecutive
// Describe whether the last part of a given dimension of A/B/D/E is continues dim.
// in the memory or not.
std
::
array
<
index_t
,
NumATensor
>
as_continous_dim_
;
std
::
array
<
bool
,
NumATensor
>
as_mz_consecutive_
;
std
::
array
<
index_t
,
NumATensor
>
bs_continous_dim_
;
std
::
array
<
bool
,
NumATensor
>
as_kz_consecutive_
;
std
::
array
<
index_t
,
NumBTensor
>
ds_continous_dim_
;
std
::
array
<
bool
,
NumBTensor
>
bs_nz_consecutive_
;
index_t
e_continous_dim_
;
std
::
array
<
bool
,
NumBTensor
>
bs_kz_consecutive_
;
std
::
array
<
bool
,
NumDTensor
>
ds_nz_consecutive_
;
bool
e_nz_consecutive_
;
std
::
array
<
index_t
,
NumATensor
>
as_max_read_elems_
;
std
::
array
<
index_t
,
NumATensor
>
as_max_read_elems_
;
std
::
array
<
index_t
,
NumBTensor
>
bs_max_read_elems_
;
std
::
array
<
index_t
,
NumBTensor
>
bs_max_read_elems_
;
...
@@ -659,9 +651,9 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
...
@@ -659,9 +651,9 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
const
bool
valid_a_vector_size
=
const
bool
valid_a_vector_size
=
arg
.
as_max_read_elems_
[
i
]
%
ABlockTransferSrcScalarPerVector
==
0
;
arg
.
as_max_read_elems_
[
i
]
%
ABlockTransferSrcScalarPerVector
==
0
;
const
bool
valid_a_access_dim_m
=
const
bool
valid_a_access_dim_m
=
ABlockTransferSrcVectorDim
==
1
&&
arg
.
as_
mz_consecutive_
[
i
]
;
ABlockTransferSrcVectorDim
==
1
&&
arg
.
as_
continous_dim_
[
i
]
==
0
;
const
bool
valid_a_access_dim_k
=
const
bool
valid_a_access_dim_k
=
ABlockTransferSrcVectorDim
==
2
&&
arg
.
as_
kz_consecutive_
[
i
]
;
ABlockTransferSrcVectorDim
==
2
&&
arg
.
as_
continous_dim_
[
i
]
==
1
;
const
bool
valid_a_access_dim
=
valid_a_access_dim_m
||
valid_a_access_dim_k
;
const
bool
valid_a_access_dim
=
valid_a_access_dim_m
||
valid_a_access_dim_k
;
if
(
!
((
valid_a_vector_size
&&
valid_a_access_dim
)
||
if
(
!
((
valid_a_vector_size
&&
valid_a_access_dim
)
||
ABlockTransferSrcScalarPerVector
==
1
))
ABlockTransferSrcScalarPerVector
==
1
))
...
@@ -679,9 +671,9 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
...
@@ -679,9 +671,9 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
const
bool
valid_b_vector_size
=
const
bool
valid_b_vector_size
=
arg
.
bs_max_read_elems_
[
i
]
%
BBlockTransferSrcScalarPerVector
==
0
;
arg
.
bs_max_read_elems_
[
i
]
%
BBlockTransferSrcScalarPerVector
==
0
;
const
bool
valid_b_access_dim_n
=
const
bool
valid_b_access_dim_n
=
BBlockTransferSrcVectorDim
==
1
&&
arg
.
bs_
nz_consecutive_
[
i
]
;
BBlockTransferSrcVectorDim
==
1
&&
arg
.
bs_
continous_dim_
[
i
]
==
0
;
const
bool
valid_b_access_dim_k
=
const
bool
valid_b_access_dim_k
=
BBlockTransferSrcVectorDim
==
2
&&
arg
.
bs_
kz_consecutive_
[
i
]
;
BBlockTransferSrcVectorDim
==
2
&&
arg
.
bs_
continous_dim_
[
i
]
==
1
;
const
bool
valid_b_access_dim
=
valid_b_access_dim_n
||
valid_b_access_dim_k
;
const
bool
valid_b_access_dim
=
valid_b_access_dim_n
||
valid_b_access_dim_k
;
if
(
!
((
valid_b_vector_size
&&
valid_b_access_dim
)
||
if
(
!
((
valid_b_vector_size
&&
valid_b_access_dim
)
||
BBlockTransferSrcScalarPerVector
==
1
))
BBlockTransferSrcScalarPerVector
==
1
))
...
@@ -699,7 +691,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
...
@@ -699,7 +691,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
const
bool
valid_d_vector_size
=
const
bool
valid_d_vector_size
=
arg
.
ds_max_read_elems_
[
i
]
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
;
arg
.
ds_max_read_elems_
[
i
]
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
;
// Vector read of Ds is always on N dimension.
// Vector read of Ds is always on N dimension.
const
bool
valid_d_access_dim
=
arg
.
ds_
nz_consecutive_
[
i
]
;
const
bool
valid_d_access_dim
=
arg
.
ds_
continous_dim_
[
i
]
==
1
;
if
(
!
((
valid_d_vector_size
&&
valid_d_access_dim
)
||
if
(
!
((
valid_d_vector_size
&&
valid_d_access_dim
)
||
CDEBlockTransferScalarPerVector_NPerBlock
==
1
))
CDEBlockTransferScalarPerVector_NPerBlock
==
1
))
{
{
...
@@ -714,7 +706,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
...
@@ -714,7 +706,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
const
bool
valid_e_vector_size
=
const
bool
valid_e_vector_size
=
arg
.
e_max_write_elems_
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
;
arg
.
e_max_write_elems_
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
;
// Vector write of E is always on N dimension.
// Vector write of E is always on N dimension.
const
bool
valid_e_access_dim
=
arg
.
e_
nz_consecutive_
;
const
bool
valid_e_access_dim
=
arg
.
e_
continous_dim_
==
1
;
if
(
!
((
valid_e_vector_size
&&
valid_e_access_dim
)
||
if
(
!
((
valid_e_vector_size
&&
valid_e_access_dim
)
||
CDEBlockTransferScalarPerVector_NPerBlock
==
1
))
CDEBlockTransferScalarPerVector_NPerBlock
==
1
))
{
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
View file @
933951ed
...
@@ -442,25 +442,19 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -442,25 +442,19 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
}
}
// for sanity check of vector memory access
// for sanity check of vector memory access
a_mz_consecutive_
=
a_ms_ks_strides
[
NumDimM
-
1
]
==
1
;
tie
(
a_continous_dim_
,
a_max_read_elems_
)
=
a_kz_consecutive_
=
a_ms_ks_strides
[
NumDimM
+
NumDimK
-
1
]
==
1
;
a_max_read_elems_
=
CalculateMaxRead
<
NumDimM
,
NumDimK
>
(
a_ms_ks_lengths
,
a_ms_ks_strides
);
CalculateMaxRead
<
NumDimM
,
NumDimK
>
(
a_ms_ks_lengths
,
a_ms_ks_strides
);
b_nz_consecutive_
=
b_ns_ks_strides
[
NumDimN
-
1
]
==
1
;
tie
(
b_continous_dim_
,
b_max_read_elems_
)
=
b_kz_consecutive_
=
b_ns_ks_strides
[
NumDimN
+
NumDimK
-
1
]
==
1
;
b_max_read_elems_
=
CalculateMaxRead
<
NumDimN
,
NumDimK
>
(
b_ns_ks_lengths
,
b_ns_ks_strides
);
CalculateMaxRead
<
NumDimN
,
NumDimK
>
(
b_ns_ks_lengths
,
b_ns_ks_strides
);
for
(
index_t
i
=
0
;
i
<
NumDTensor
;
++
i
)
for
(
index_t
i
=
0
;
i
<
NumDTensor
;
++
i
)
{
{
ds_nz_consecutive_
[
i
]
=
ds_ms_ns_strides
[
i
][
NumDimM
+
NumDimN
-
1
]
==
1
;
tie
(
ds_continous_dim_
[
i
],
ds_max_read_elems_
[
i
])
=
ds_max_read_elems_
[
i
]
=
CalculateMaxRead
<
NumDimM
,
NumDimN
>
(
ds_ms_ns_lengths
[
i
],
ds_ms_ns_strides
[
i
]);
CalculateMaxRead
<
NumDimM
,
NumDimN
>
(
ds_ms_ns_lengths
[
i
],
ds_ms_ns_strides
[
i
]);
}
}
e_nz_consecutive_
=
e_ms_ns_strides
[
NumDimM
+
NumDimN
-
1
]
==
1
;
tie
(
e_continous_dim_
,
e_max_write_elems_
)
=
e_max_write_elems_
=
CalculateMaxRead
<
NumDimM
,
NumDimN
>
(
e_ms_ns_lengths
,
e_ms_ns_strides
);
CalculateMaxRead
<
NumDimM
,
NumDimN
>
(
e_ms_ns_lengths
,
e_ms_ns_strides
);
}
}
...
@@ -501,14 +495,11 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -501,14 +495,11 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
BElementwiseOperation
b_element_op_
;
BElementwiseOperation
b_element_op_
;
CDEElementwiseOperation
cde_element_op_
;
CDEElementwiseOperation
cde_element_op_
;
// Describe whether the last part of a given dimension of A/B/D/E is consecutive
// Describe whether the last part of a given dimension of A/B/D/E is continues dim.
// in the memory or not.
index_t
a_continous_dim_
;
bool
a_mz_consecutive_
;
index_t
b_continous_dim_
;
bool
a_kz_consecutive_
;
std
::
array
<
index_t
,
NumDTensor
>
ds_continous_dim_
;
bool
b_nz_consecutive_
;
index_t
e_continous_dim_
;
bool
b_kz_consecutive_
;
std
::
array
<
bool
,
NumDTensor
>
ds_nz_consecutive_
;
bool
e_nz_consecutive_
;
index_t
a_max_read_elems_
;
index_t
a_max_read_elems_
;
index_t
b_max_read_elems_
;
index_t
b_max_read_elems_
;
...
@@ -624,8 +615,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -624,8 +615,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const
bool
valid_a_vector_size
=
const
bool
valid_a_vector_size
=
arg
.
a_max_read_elems_
%
ABlockTransferSrcScalarPerVector
==
0
;
arg
.
a_max_read_elems_
%
ABlockTransferSrcScalarPerVector
==
0
;
const
bool
valid_a_access_dim_m
=
ABlockTransferSrcVectorDim
==
1
&&
arg
.
a_mz_consecutive_
;
const
bool
valid_a_access_dim_m
=
const
bool
valid_a_access_dim_k
=
ABlockTransferSrcVectorDim
==
2
&&
arg
.
a_kz_consecutive_
;
ABlockTransferSrcVectorDim
==
1
&&
arg
.
a_continous_dim_
==
0
;
const
bool
valid_a_access_dim_k
=
ABlockTransferSrcVectorDim
==
2
&&
arg
.
a_continous_dim_
==
1
;
const
bool
valid_a_access_dim
=
const
bool
valid_a_access_dim
=
valid_a_access_dim_m
||
valid_a_access_dim_k
||
ABlockTransferSrcScalarPerVector
==
1
;
valid_a_access_dim_m
||
valid_a_access_dim_k
||
ABlockTransferSrcScalarPerVector
==
1
;
if
(
!
(
valid_a_vector_size
&&
valid_a_access_dim
))
if
(
!
(
valid_a_vector_size
&&
valid_a_access_dim
))
...
@@ -635,8 +628,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -635,8 +628,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const
bool
valid_b_vector_size
=
const
bool
valid_b_vector_size
=
arg
.
b_max_read_elems_
%
BBlockTransferSrcScalarPerVector
==
0
;
arg
.
b_max_read_elems_
%
BBlockTransferSrcScalarPerVector
==
0
;
const
bool
valid_b_access_dim_n
=
BBlockTransferSrcVectorDim
==
1
&&
arg
.
b_nz_consecutive_
;
const
bool
valid_b_access_dim_n
=
const
bool
valid_b_access_dim_k
=
BBlockTransferSrcVectorDim
==
2
&&
arg
.
b_kz_consecutive_
;
BBlockTransferSrcVectorDim
==
1
&&
arg
.
b_continous_dim_
==
0
;
const
bool
valid_b_access_dim_k
=
BBlockTransferSrcVectorDim
==
2
&&
arg
.
b_continous_dim_
==
1
;
const
bool
valid_b_access_dim
=
const
bool
valid_b_access_dim
=
valid_b_access_dim_n
||
valid_b_access_dim_k
||
BBlockTransferSrcScalarPerVector
==
1
;
valid_b_access_dim_n
||
valid_b_access_dim_k
||
BBlockTransferSrcScalarPerVector
==
1
;
if
(
!
(
valid_b_vector_size
&&
valid_b_access_dim
))
if
(
!
(
valid_b_vector_size
&&
valid_b_access_dim
))
...
@@ -650,7 +645,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -650,7 +645,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
arg
.
ds_max_read_elems_
[
i
]
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
;
arg
.
ds_max_read_elems_
[
i
]
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
;
// Vector read of Ds is always on N dimension.
// Vector read of Ds is always on N dimension.
const
bool
valid_d_access_dim
=
const
bool
valid_d_access_dim
=
arg
.
ds_
nz_consecutive_
[
i
]
||
CDEBlockTransferScalarPerVector_NPerBlock
==
1
;
arg
.
ds_
continous_dim_
[
i
]
==
1
||
CDEBlockTransferScalarPerVector_NPerBlock
==
1
;
if
(
!
(
valid_d_vector_size
&&
valid_d_access_dim
))
if
(
!
(
valid_d_vector_size
&&
valid_d_access_dim
))
{
{
valid_ds_access
=
false
;
valid_ds_access
=
false
;
...
@@ -665,7 +660,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -665,7 +660,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
arg
.
e_max_write_elems_
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
;
arg
.
e_max_write_elems_
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
;
// Vector write of E is always on N dimension.
// Vector write of E is always on N dimension.
const
bool
valid_e_access_dim
=
const
bool
valid_e_access_dim
=
arg
.
e_
nz_consecutive_
||
CDEBlockTransferScalarPerVector_NPerBlock
==
1
;
arg
.
e_
continous_dim_
==
1
||
CDEBlockTransferScalarPerVector_NPerBlock
==
1
;
if
(
!
(
valid_e_vector_size
&&
valid_e_access_dim
))
if
(
!
(
valid_e_vector_size
&&
valid_e_access_dim
))
{
{
return
false
;
return
false
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp
View file @
933951ed
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023
-2024
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -50,25 +50,53 @@ auto CalculateMaxRead(const std::vector<index_t>& lengths, const std::vector<ind
...
@@ -50,25 +50,53 @@ auto CalculateMaxRead(const std::vector<index_t>& lengths, const std::vector<ind
}
}
// Determine the beginning and end idx of the group representing the FCD.
// Determine the beginning and end idx of the group representing the FCD.
index_t
begin_idx
,
end_idx
;
index_t
begin_idx
,
end_idx
,
continous_dim
,
consecutive_stride
=
1
;
if
(
strides
[
NumDim1
-
1
]
==
1
)
if
(
strides
[
NumDim1
-
1
]
==
1
&&
strides
[
NumDim1
+
NumDim2
-
1
]
==
1
)
{
{
begin_idx
=
0
;
// MZ or KZ are ones
end_idx
=
NumDim1
-
1
;
bool
dims1_are_ones
=
true
;
for
(
index_t
dim_idx
=
0
;
dim_idx
<
NumDim1
;
dim_idx
++
)
{
if
(
lengths
[
dim_idx
]
!=
1
)
{
dims1_are_ones
=
false
;
}
}
if
(
dims1_are_ones
)
{
begin_idx
=
NumDim1
;
end_idx
=
NumDim1
+
NumDim2
-
1
;
continous_dim
=
1
;
}
else
{
begin_idx
=
0
;
end_idx
=
NumDim1
-
1
;
continous_dim
=
0
;
}
}
else
if
(
strides
[
NumDim1
-
1
]
==
1
)
{
begin_idx
=
0
;
end_idx
=
NumDim1
-
1
;
continous_dim
=
0
;
}
}
else
if
(
strides
[
NumDim1
+
NumDim2
-
1
]
==
1
)
else
if
(
strides
[
NumDim1
+
NumDim2
-
1
]
==
1
)
{
{
begin_idx
=
NumDim1
;
begin_idx
=
NumDim1
;
end_idx
=
NumDim1
+
NumDim2
-
1
;
end_idx
=
NumDim1
+
NumDim2
-
1
;
continous_dim
=
1
;
}
}
else
else
{
{
// The dimension consecutive in memory is not the last dimension of any group, so only
// The dimension consecutive in memory is not the last dimension of any group, so only
// one element can be read/written at once.
// one element can be read/written at once.
return
1
;
consecutive_stride
=
1
;
continous_dim
=
0
;
return
make_tuple
(
continous_dim
,
consecutive_stride
);
}
}
index_t
consecutive_stride
=
1
;
for
(
index_t
dim_idx
=
end_idx
;
dim_idx
>=
begin_idx
;
--
dim_idx
)
for
(
index_t
dim_idx
=
end_idx
;
dim_idx
>=
begin_idx
;
--
dim_idx
)
{
{
if
(
strides
[
dim_idx
]
==
consecutive_stride
)
if
(
strides
[
dim_idx
]
==
consecutive_stride
)
...
@@ -81,7 +109,7 @@ auto CalculateMaxRead(const std::vector<index_t>& lengths, const std::vector<ind
...
@@ -81,7 +109,7 @@ auto CalculateMaxRead(const std::vector<index_t>& lengths, const std::vector<ind
}
}
}
}
const
index_t
max_subsequent_elems
=
consecutive_stride
;
const
index_t
max_subsequent_elems
=
consecutive_stride
;
return
max_subsequent_elems
;
return
make_tuple
(
continous_dim
,
max_subsequent_elems
)
;
}
}
}
// namespace device
}
// namespace device
...
...
test/contraction/test_contraction_xdl.cpp
View file @
933951ed
...
@@ -212,4 +212,10 @@ TYPED_TEST(TestContractionScaleMixedPrecision, scale)
...
@@ -212,4 +212,10 @@ TYPED_TEST(TestContractionScaleMixedPrecision, scale)
this
->
template
Run
<
6
>({{
1
,
1
,
1
,
3
,
2
,
3
},
{
1
,
1
,
1
,
3
,
2
,
3
},
{
1
,
1
,
1
,
2
,
2
,
4
}});
this
->
template
Run
<
6
>({{
1
,
1
,
1
,
3
,
2
,
3
},
{
1
,
1
,
1
,
3
,
2
,
3
},
{
1
,
1
,
1
,
2
,
2
,
4
}});
this
->
template
Run
<
2
>({{
16
,
8
},
{
16
,
8
},
{
16
,
8
}});
this
->
template
Run
<
2
>({{
16
,
8
},
{
16
,
8
},
{
16
,
8
}});
this
->
template
Run
<
2
>({{
8
,
16
},
{
16
,
8
},
{
8
,
16
}});
this
->
template
Run
<
2
>({{
8
,
16
},
{
16
,
8
},
{
8
,
16
}});
// special cases
this
->
template
Run
<
2
>({{
1
,
1
},
{
16
,
8
},
{
8
,
16
}});
this
->
template
Run
<
2
>({{
8
,
16
},
{
16
,
8
},
{
1
,
1
}});
this
->
template
Run
<
2
>({{
8
,
16
},
{
1
,
1
},
{
8
,
16
}});
this
->
template
Run
<
2
>({{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment