Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
4d70c71b
Commit
4d70c71b
authored
Sep 30, 2020
by
Chao Liu
Browse files
refactor array
parent
5a2498d1
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
149 additions
and
158 deletions
+149
-158
composable_kernel/include/kernel_algorithm/dummy_dynamic_transform_v1.hpp
...l/include/kernel_algorithm/dummy_dynamic_transform_v1.hpp
+73
-68
composable_kernel/include/tensor_description/dynamic_tensor_descriptor.hpp
.../include/tensor_description/dynamic_tensor_descriptor.hpp
+17
-23
composable_kernel/include/tensor_description/dynamic_tensor_descriptor_helper.hpp
...e/tensor_description/dynamic_tensor_descriptor_helper.hpp
+2
-2
composable_kernel/include/tensor_description/dynamic_tensor_descriptor_v2.hpp
...clude/tensor_description/dynamic_tensor_descriptor_v2.hpp
+16
-13
composable_kernel/include/utility/array.hpp
composable_kernel/include/utility/array.hpp
+0
-12
composable_kernel/include/utility/array_element_picker.hpp
composable_kernel/include/utility/array_element_picker.hpp
+6
-0
composable_kernel/include/utility/array_helper.hpp
composable_kernel/include/utility/array_helper.hpp
+21
-30
composable_kernel/include/utility/print.hpp
composable_kernel/include/utility/print.hpp
+3
-2
driver/include/device_dummy_dynamic_transform_v1.hpp
driver/include/device_dummy_dynamic_transform_v1.hpp
+8
-8
driver/include/device_dummy_dynamic_transform_v2.hpp
driver/include/device_dummy_dynamic_transform_v2.hpp
+3
-0
No files found.
composable_kernel/include/kernel_algorithm/dummy_dynamic_transform_v1.hpp
View file @
4d70c71b
...
@@ -13,39 +13,39 @@ __host__ __device__ constexpr auto
...
@@ -13,39 +13,39 @@ __host__ __device__ constexpr auto
map_convolution_into_gemm_v1
(
const
WeiDesc
&
wei_k_c_y_x_global_desc
,
map_convolution_into_gemm_v1
(
const
WeiDesc
&
wei_k_c_y_x_global_desc
,
const
InDesc
&
in_n_c_hi_wi_global_desc
,
const
InDesc
&
in_n_c_hi_wi_global_desc
,
const
OutDesc
&
out_n_k_ho_wo_global_desc
,
const
OutDesc
&
out_n_k_ho_wo_global_desc
,
const
Array
<
index_t
,
2
>
conv_strides
,
const
MultiIndex
<
2
>
&
conv_strides
,
const
Array
<
index_t
,
2
>
conv_dilations
,
const
MultiIndex
<
2
>
&
conv_dilations
,
const
Array
<
index_t
,
2
>
in_left_pads
,
const
MultiIndex
<
2
>
&
in_left_pads
,
const
Array
<
index_t
,
2
>
in_right_pads
)
const
MultiIndex
<
2
>
&
in_right_pads
)
{
{
constexpr
auto
i
0
=
Number
<
0
>
{};
constexpr
auto
I
0
=
Number
<
0
>
{};
constexpr
auto
i
1
=
Number
<
1
>
{};
constexpr
auto
I
1
=
Number
<
1
>
{};
constexpr
auto
i
2
=
Number
<
2
>
{};
constexpr
auto
I
2
=
Number
<
2
>
{};
constexpr
auto
i
3
=
Number
<
3
>
{};
constexpr
auto
I
3
=
Number
<
3
>
{};
const
index_t
N
=
in_n_c_hi_wi_global_desc
.
GetLength
(
i
0
);
const
index_t
N
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I
0
);
const
index_t
C
=
in_n_c_hi_wi_global_desc
.
GetLength
(
i
1
);
const
index_t
C
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I
1
);
const
index_t
K
=
out_n_k_ho_wo_global_desc
.
GetLength
(
i
1
);
const
index_t
K
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I
1
);
const
index_t
Y
=
wei_k_c_y_x_global_desc
.
GetLength
(
i
2
);
const
index_t
Y
=
wei_k_c_y_x_global_desc
.
GetLength
(
I
2
);
const
index_t
X
=
wei_k_c_y_x_global_desc
.
GetLength
(
i
3
);
const
index_t
X
=
wei_k_c_y_x_global_desc
.
GetLength
(
I
3
);
const
index_t
Hi
=
in_n_c_hi_wi_global_desc
.
GetLength
(
i
2
);
const
index_t
Hi
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I
2
);
const
index_t
Wi
=
in_n_c_hi_wi_global_desc
.
GetLength
(
i
3
);
const
index_t
Wi
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I
3
);
const
index_t
Ho
=
out_n_k_ho_wo_global_desc
.
GetLength
(
i
2
);
const
index_t
Ho
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I
2
);
const
index_t
Wo
=
out_n_k_ho_wo_global_desc
.
GetLength
(
i
3
);
const
index_t
Wo
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I
3
);
const
index_t
ConvStrideH
=
conv_strides
[
i
0
];
const
index_t
ConvStrideH
=
conv_strides
[
I
0
];
const
index_t
ConvStrideW
=
conv_strides
[
i
1
];
const
index_t
ConvStrideW
=
conv_strides
[
I
1
];
const
index_t
ConvDilationH
=
conv_dilations
[
i
0
];
const
index_t
ConvDilationH
=
conv_dilations
[
I
0
];
const
index_t
ConvDilationW
=
conv_dilations
[
i
1
];
const
index_t
ConvDilationW
=
conv_dilations
[
I
1
];
const
index_t
InLeftPadH
=
in_left_pads
[
i
0
];
const
index_t
InLeftPadH
=
in_left_pads
[
I
0
];
const
index_t
InLeftPadW
=
in_left_pads
[
i
1
];
const
index_t
InLeftPadW
=
in_left_pads
[
I
1
];
const
index_t
InRightPadH
=
in_right_pads
[
i
0
];
const
index_t
InRightPadH
=
in_right_pads
[
I
0
];
const
index_t
InRightPadW
=
in_right_pads
[
i
1
];
const
index_t
InRightPadW
=
in_right_pads
[
I
1
];
// input tensor
// input tensor
const
auto
in_n_c_hip_wip_global_desc
=
transform_dynamic_tensor_descriptor
(
const
auto
in_n_c_hip_wip_global_desc
=
transform_dynamic_tensor_descriptor
(
...
@@ -64,8 +64,8 @@ map_convolution_into_gemm_v1(const WeiDesc& wei_k_c_y_x_global_desc,
...
@@ -64,8 +64,8 @@ map_convolution_into_gemm_v1(const WeiDesc& wei_k_c_y_x_global_desc,
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
index_t
Hip
=
in_n_c_hip_wip_global_desc
.
GetLength
(
i
2
);
const
index_t
Hip
=
in_n_c_hip_wip_global_desc
.
GetLength
(
I
2
);
const
index_t
Wip
=
in_n_c_hip_wip_global_desc
.
GetLength
(
i
3
);
const
index_t
Wip
=
in_n_c_hip_wip_global_desc
.
GetLength
(
I
3
);
const
auto
in_n_c_y_ho_x_wo_global_desc
=
transform_dynamic_tensor_descriptor
(
const
auto
in_n_c_y_ho_x_wo_global_desc
=
transform_dynamic_tensor_descriptor
(
in_n_c_hip_wip_global_desc
,
in_n_c_hip_wip_global_desc
,
...
@@ -97,55 +97,60 @@ struct DummyDynamicTransform_v1
...
@@ -97,55 +97,60 @@ struct DummyDynamicTransform_v1
const
WeiDesc
wei_k_c_y_x_global_desc
,
const
WeiDesc
wei_k_c_y_x_global_desc
,
const
InDesc
in_n_c_hi_wi_global_desc
,
const
InDesc
in_n_c_hi_wi_global_desc
,
const
OutDesc
out_n_k_ho_wo_global_desc
,
const
OutDesc
out_n_k_ho_wo_global_desc
,
const
Array
<
index_t
,
2
>
conv_strides
,
const
MultiIndex
<
2
>
&
conv_strides
,
const
Array
<
index_t
,
2
>
conv_dilations
,
const
MultiIndex
<
2
>
&
conv_dilations
,
const
Array
<
index_t
,
2
>
in_left_pads
,
const
MultiIndex
<
2
>
&
in_left_pads
,
const
Array
<
index_t
,
2
>
in_right_pads
)
const
const
MultiIndex
<
2
>
&
in_right_pads
)
const
{
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
#if 1
#if 1
const
index_t
N
=
in_n_c_hi_wi_global_desc
.
GetLength
(
0
);
const
index_t
N
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I
0
);
const
index_t
C
=
in_n_c_hi_wi_global_desc
.
GetLength
(
1
);
const
index_t
C
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I
1
);
const
index_t
K
=
out_n_k_ho_wo_global_desc
.
GetLength
(
1
);
const
index_t
K
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I
1
);
const
index_t
Y
=
wei_k_c_y_x_global_desc
.
GetLength
(
2
);
const
index_t
Y
=
wei_k_c_y_x_global_desc
.
GetLength
(
I
2
);
const
index_t
X
=
wei_k_c_y_x_global_desc
.
GetLength
(
3
);
const
index_t
X
=
wei_k_c_y_x_global_desc
.
GetLength
(
I
3
);
const
index_t
Hi
=
in_n_c_hi_wi_global_desc
.
GetLength
(
2
);
const
index_t
Hi
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I
2
);
const
index_t
Wi
=
in_n_c_hi_wi_global_desc
.
GetLength
(
3
);
const
index_t
Wi
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I
3
);
const
index_t
Ho
=
out_n_k_ho_wo_global_desc
.
GetLength
(
2
);
const
index_t
Ho
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I
2
);
const
index_t
Wo
=
out_n_k_ho_wo_global_desc
.
GetLength
(
3
);
const
index_t
Wo
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I
3
);
const
index_t
ConvStrideH
=
conv_strides
[
0
];
const
index_t
ConvStrideH
=
conv_strides
[
I
0
];
const
index_t
ConvStrideW
=
conv_strides
[
1
];
const
index_t
ConvStrideW
=
conv_strides
[
I
1
];
const
index_t
ConvDilationH
=
conv_dilations
[
0
];
const
index_t
ConvDilationH
=
conv_dilations
[
I
0
];
const
index_t
ConvDilationW
=
conv_dilations
[
1
];
const
index_t
ConvDilationW
=
conv_dilations
[
I
1
];
const
index_t
InLeftPadH
=
in_left_pads
[
0
];
const
index_t
InLeftPadH
=
in_left_pads
[
I
0
];
const
index_t
InLeftPadW
=
in_left_pads
[
1
];
const
index_t
InLeftPadW
=
in_left_pads
[
I
1
];
const
index_t
InRightPadH
=
in_right_pads
[
0
];
const
index_t
InRightPadH
=
in_right_pads
[
I
0
];
const
index_t
InRightPadW
=
in_right_pads
[
1
];
const
index_t
InRightPadW
=
in_right_pads
[
I
1
];
#else
#else
const
index_t
N
=
in_n_c_hi_wi_global_desc
.
GetLength
(
0
);
const
index_t
N
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I
0
);
const
index_t
C
=
in_n_c_hi_wi_global_desc
.
GetLength
(
1
);
const
index_t
C
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I
1
);
const
index_t
Y
=
3
;
const
index_t
Y
=
3
;
const
index_t
X
=
3
;
const
index_t
X
=
3
;
const
index_t
Hi
=
in_n_c_hi_wi_global_desc
.
GetLength
(
2
);
const
index_t
Hi
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I
2
);
const
index_t
Wi
=
in_n_c_hi_wi_global_desc
.
GetLength
(
3
);
const
index_t
Wi
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I
3
);
const
index_t
ConvStrideH
=
conv_strides
[
0
];
const
index_t
ConvStrideH
=
conv_strides
[
I
0
];
const
index_t
ConvStrideW
=
conv_strides
[
1
];
const
index_t
ConvStrideW
=
conv_strides
[
I
1
];
const
index_t
ConvDilationH
=
conv_dilations
[
0
];
const
index_t
ConvDilationH
=
conv_dilations
[
I
0
];
const
index_t
ConvDilationW
=
conv_dilations
[
1
];
const
index_t
ConvDilationW
=
conv_dilations
[
I
1
];
const
index_t
InLeftPadH
=
in_left_pads
[
0
];
const
index_t
InLeftPadH
=
in_left_pads
[
I
0
];
const
index_t
InLeftPadW
=
in_left_pads
[
1
];
const
index_t
InLeftPadW
=
in_left_pads
[
I
1
];
const
index_t
InRightPadH
=
in_right_pads
[
0
];
const
index_t
InRightPadH
=
in_right_pads
[
I
0
];
const
index_t
InRightPadW
=
in_right_pads
[
1
];
const
index_t
InRightPadW
=
in_right_pads
[
I
1
];
#endif
#endif
// define transform
// define transform
...
@@ -537,10 +542,10 @@ struct DummyDynamicTransform_v1
...
@@ -537,10 +542,10 @@ struct DummyDynamicTransform_v1
const
WeiDesc
wei_k_c_y_x_global_desc
,
const
WeiDesc
wei_k_c_y_x_global_desc
,
const
InDesc
in_n_c_hi_wi_global_desc
,
const
InDesc
in_n_c_hi_wi_global_desc
,
const
OutDesc
out_n_k_ho_wo_global_desc
,
const
OutDesc
out_n_k_ho_wo_global_desc
,
const
Array
<
index_t
,
2
>
conv_strides
,
const
MultiIndex
<
2
>
&
conv_strides
,
const
Array
<
index_t
,
2
>
conv_dilations
,
const
MultiIndex
<
2
>
&
conv_dilations
,
const
Array
<
index_t
,
2
>
in_left_pads
,
const
MultiIndex
<
2
>
&
in_left_pads
,
const
Array
<
index_t
,
2
>
in_right_pads
)
const
const
MultiIndex
<
2
>
&
in_right_pads
)
const
{
{
const
auto
transformed_tensor_descs
=
const
auto
transformed_tensor_descs
=
map_convolution_into_gemm_v1
(
wei_k_c_y_x_global_desc
,
map_convolution_into_gemm_v1
(
wei_k_c_y_x_global_desc
,
...
@@ -598,10 +603,10 @@ struct DummyDynamicTransform_v1
...
@@ -598,10 +603,10 @@ struct DummyDynamicTransform_v1
const
WeiDesc
wei_k_c_y_x_global_desc
,
const
WeiDesc
wei_k_c_y_x_global_desc
,
const
InDesc
in_n_c_hi_wi_global_desc
,
const
InDesc
in_n_c_hi_wi_global_desc
,
const
OutDesc
out_n_k_ho_wo_global_desc
,
const
OutDesc
out_n_k_ho_wo_global_desc
,
const
Array
<
index_t
,
2
>
conv_strides
,
const
MultiIndex
<
2
>
&
conv_strides
,
const
Array
<
index_t
,
2
>
conv_dilations
,
const
MultiIndex
<
2
>
&
conv_dilations
,
const
Array
<
index_t
,
2
>
in_left_pads
,
const
MultiIndex
<
2
>
&
in_left_pads
,
const
Array
<
index_t
,
2
>
in_right_pads
)
const
const
MultiIndex
<
2
>
&
in_right_pads
)
const
{
{
Run_2
(
p_wei_global
,
Run_2
(
p_wei_global
,
p_in_global
,
p_in_global
,
...
...
composable_kernel/include/tensor_description/dynamic_tensor_descriptor.hpp
View file @
4d70c71b
...
@@ -31,9 +31,17 @@ struct DynamicNativeTensorDescriptor
...
@@ -31,9 +31,17 @@ struct DynamicNativeTensorDescriptor
__host__
__device__
constexpr
auto
GetStrides
()
const
{
return
strides_
;
}
__host__
__device__
constexpr
auto
GetStrides
()
const
{
return
strides_
;
}
__host__
__device__
constexpr
index_t
GetLength
(
index_t
idim
)
const
{
return
lengths_
[
idim
];
}
template
<
index_t
IDim
>
__host__
__device__
constexpr
index_t
GetLength
(
Number
<
IDim
>
)
const
{
return
lengths_
[
Number
<
IDim
>
{}];
}
__host__
__device__
constexpr
index_t
GetStride
(
index_t
idim
)
const
{
return
strides_
[
idim
];
}
template
<
index_t
IDim
>
__host__
__device__
constexpr
index_t
GetStride
(
Number
<
IDim
>
)
const
{
return
strides_
[
Number
<
IDim
>
{}];
}
__host__
__device__
constexpr
index_t
GetElementSize
()
const
__host__
__device__
constexpr
index_t
GetElementSize
()
const
{
{
...
@@ -44,11 +52,7 @@ struct DynamicNativeTensorDescriptor
...
@@ -44,11 +52,7 @@ struct DynamicNativeTensorDescriptor
{
{
index_t
space
=
1
;
index_t
space
=
1
;
#pragma unroll
static_for
<
0
,
NDim
,
1
>
{}([
&
](
auto
i
)
{
space
+=
(
GetLength
(
i
)
-
1
)
*
GetStride
(
i
);
});
for
(
index_t
i
=
0
;
i
<
NDim
;
++
i
)
{
space
+=
(
GetLength
(
i
)
-
1
)
*
GetStride
(
i
);
}
return
space
;
return
space
;
}
}
...
@@ -58,11 +62,7 @@ struct DynamicNativeTensorDescriptor
...
@@ -58,11 +62,7 @@ struct DynamicNativeTensorDescriptor
{
{
index_t
offset
=
0
;
index_t
offset
=
0
;
#pragma unroll
static_for
<
0
,
NDim
,
1
>
{}([
&
](
auto
i
)
{
offset
+=
idx
[
i
]
*
GetStride
(
i
);
});
for
(
index_t
i
=
0
;
i
<
NDim
;
++
i
)
{
offset
+=
idx
[
i
]
*
GetStride
(
i
);
}
return
offset
;
return
offset
;
}
}
...
@@ -78,11 +78,8 @@ struct DynamicNativeTensorDescriptor
...
@@ -78,11 +78,8 @@ struct DynamicNativeTensorDescriptor
{
{
bool
flag
=
true
;
bool
flag
=
true
;
#pragma unroll
static_for
<
0
,
NDim
,
1
>
{}(
for
(
index_t
i
=
0
;
i
<
NDim
;
++
i
)
[
&
](
auto
i
)
{
flag
=
flag
&&
idx
[
i
]
>=
0
&&
idx
[
i
]
<
GetLength
(
i
);
});
{
flag
=
flag
&&
idx
[
i
]
>=
0
&&
idx
[
i
]
<
GetLength
(
i
);
}
return
flag
;
return
flag
;
}
}
...
@@ -139,7 +136,7 @@ struct DynamicTransformedTensorDescriptor
...
@@ -139,7 +136,7 @@ struct DynamicTransformedTensorDescriptor
template
<
typename
...
Xs
>
template
<
typename
...
Xs
>
__host__
__device__
constexpr
auto
operator
()(
Xs
...
xs
)
const
__host__
__device__
constexpr
auto
operator
()(
Xs
...
xs
)
const
{
{
return
merge_
array
s
(
xs
...);
return
array
_cat
(
xs
...);
}
}
};
};
...
@@ -306,11 +303,8 @@ struct DynamicTransformedTensorDescriptor
...
@@ -306,11 +303,8 @@ struct DynamicTransformedTensorDescriptor
{
{
bool
flag
=
true
;
bool
flag
=
true
;
#pragma unroll
static_for
<
0
,
NDimUp
,
1
>
{}(
for
(
index_t
i
=
0
;
i
<
NDimUp
;
++
i
)
[
&
](
auto
i
)
{
flag
=
flag
&&
idx_up
[
i
]
>=
0
&&
idx_up
[
i
]
<
GetLength
(
i
);
});
{
flag
=
flag
&&
idx_up
[
i
]
>=
0
&&
idx_up
[
i
]
<
GetLength
(
i
);
}
return
flag
;
return
flag
;
}
}
...
...
composable_kernel/include/tensor_description/dynamic_tensor_descriptor_helper.hpp
View file @
4d70c71b
...
@@ -10,9 +10,9 @@ template <typename Lengths, typename Strides>
...
@@ -10,9 +10,9 @@ template <typename Lengths, typename Strides>
__host__
__device__
constexpr
auto
make_dynamic_native_tensor_descriptor
(
const
Lengths
&
lengths
,
__host__
__device__
constexpr
auto
make_dynamic_native_tensor_descriptor
(
const
Lengths
&
lengths
,
const
Strides
&
strides
)
const
Strides
&
strides
)
{
{
static_assert
(
Lengths
::
Get
Size
()
==
Strides
::
Get
Size
(),
"wrong! Size not the same"
);
static_assert
(
Lengths
::
Size
()
==
Strides
::
Size
(),
"wrong! Size not the same"
);
return
DynamicNativeTensorDescriptor
<
Lengths
::
Get
Size
()
>
(
lengths
,
strides
);
return
DynamicNativeTensorDescriptor
<
Lengths
::
Size
()
>
(
lengths
,
strides
);
}
}
template
<
typename
LowTensorDescriptor
,
template
<
typename
LowTensorDescriptor
,
...
...
composable_kernel/include/tensor_description/dynamic_tensor_descriptor_v2.hpp
View file @
4d70c71b
...
@@ -340,7 +340,7 @@ struct DynamicTensorCoordinateStep_v2
...
@@ -340,7 +340,7 @@ struct DynamicTensorCoordinateStep_v2
#endif
#endif
};
};
// TODO:
Fix this! This is insane, to
use an
ugly
struct instead of lambda because lambda
// TODO:
How to fix this? It
use
s
an struct instead of lambda because lambda
// doesn't have constructor, and to put it outside the scope where it is used
// doesn't have constructor, and to put it outside the scope where it is used
// (transform_dynamic_tensor_descriptor_v2) because template cannot be defined inside a function
// (transform_dynamic_tensor_descriptor_v2) because template cannot be defined inside a function
// template
// template
...
@@ -538,22 +538,25 @@ __host__ __device__ void move_dynamic_tensor_coordinate_v2(const TensorDesc& ten
...
@@ -538,22 +538,25 @@ __host__ __device__ void move_dynamic_tensor_coordinate_v2(const TensorDesc& ten
idx_hidden_pick_visible
+=
coord_step
.
GetIndexDiff
();
idx_hidden_pick_visible
+=
coord_step
.
GetIndexDiff
();
// update rest of hidden index
// update rest of hidden index
static_for
<
ntransform
-
1
,
-
1
,
-
1
>
{}([
&
tensor_desc
,
&
idx_hidden
,
&
idx_diff_hidden
](
auto
itran
)
{
static_for
<
ntransform
-
1
,
-
1
,
-
1
>
{}([
&
](
auto
itran
)
{
const
auto
&
tran
=
tensor_desc
.
GetTransforms
().
At
(
itran
);
if
(
coord_step
.
do_transforms_
[
itran
])
constexpr
auto
dims_low
=
TensorDesc
::
GetLowerDimensionIdss
().
At
(
itran
);
{
constexpr
auto
dims_up
=
TensorDesc
::
GetUpperDimensionIdss
().
At
(
itran
);
const
auto
&
tran
=
tensor_desc
.
GetTransforms
().
At
(
itran
);
constexpr
auto
dims_low
=
TensorDesc
::
GetLowerDimensionIdss
().
At
(
itran
);
constexpr
auto
dims_up
=
TensorDesc
::
GetUpperDimensionIdss
().
At
(
itran
);
// this const is for ArrayElementPicker, Array itself may not be const
// this const is for ArrayElementPicker, Array itself may not be const
const
auto
idx_up
=
pick_array_element
(
idx_hidden
,
dims_up
);
const
auto
idx_up
=
pick_array_element
(
idx_hidden
,
dims_up
);
auto
idx_low
=
pick_array_element
(
idx_hidden
,
dims_low
);
auto
idx_low
=
pick_array_element
(
idx_hidden
,
dims_low
);
const
auto
idx_diff_up
=
pick_array_element
(
idx_diff_hidden
,
dims_up
);
const
auto
idx_diff_up
=
pick_array_element
(
idx_diff_hidden
,
dims_up
);
auto
idx_diff_low
=
pick_array_element
(
idx_diff_hidden
,
dims_low
);
auto
idx_diff_low
=
pick_array_element
(
idx_diff_hidden
,
dims_low
);
tran
.
CalculateLowerIndexDiff
(
idx_diff_low
,
idx_diff_up
,
idx_low
,
idx_up
);
tran
.
CalculateLowerIndexDiff
(
idx_diff_low
,
idx_diff_up
,
idx_low
,
idx_up
);
// update idx_low
// update idx_low
idx_low
+=
idx_diff_low
;
idx_low
+=
idx_diff_low
;
}
});
});
}
}
...
...
composable_kernel/include/utility/array.hpp
View file @
4d70c71b
...
@@ -59,17 +59,5 @@ __host__ __device__ constexpr auto make_array()
...
@@ -59,17 +59,5 @@ __host__ __device__ constexpr auto make_array()
return
Array
<
X
,
0
>
{};
return
Array
<
X
,
0
>
{};
}
}
template
<
typename
TData
,
index_t
NSize
>
__host__
__device__
constexpr
auto
push_back
(
Array
<
TData
,
NSize
>&
a
,
const
TData
&
x
)
{
Array
<
TData
,
NSize
+
1
>
r
;
static_for
<
0
,
NSize
,
1
>
{}([
&
r
,
&
a
](
auto
i
)
constexpr
{
r
(
i
)
=
a
[
i
];
});
r
(
Number
<
NSize
>
{})
=
x
;
return
r
;
}
}
// namespace ck
}
// namespace ck
#endif
#endif
composable_kernel/include/utility/array_element_picker.hpp
View file @
4d70c71b
...
@@ -97,5 +97,11 @@ __host__ __device__ constexpr auto operator-=(ArrayElementPicker<Arr, Picks>& y,
...
@@ -97,5 +97,11 @@ __host__ __device__ constexpr auto operator-=(ArrayElementPicker<Arr, Picks>& y,
return
y
;
return
y
;
}
}
template
<
typename
Arr
,
typename
Picks
>
__host__
__device__
constexpr
auto
pick_array_element
(
Arr
&
a
,
Picks
)
{
return
ArrayElementPicker
<
Arr
,
Picks
>
(
a
);
}
}
// namespace ck
}
// namespace ck
#endif
#endif
composable_kernel/include/utility/array_helper.hpp
View file @
4d70c71b
#ifndef CK_ARRAY_HELPER_HPP
#ifndef CK_ARRAY_HELPER_HPP
#define CK_ARRAY_HELPER_HPP
#define CK_ARRAY_HELPER_HPP
#include "sequence.hpp"
#include "sequence_helper.hpp"
#include "tuple.hpp"
#include "tuple_helper.hpp"
#include "array.hpp"
#include "array.hpp"
#include "array_helper.hpp"
#include "statically_indexed_array.hpp"
#include "statically_indexed_array.hpp"
#include "array_element_picker.hpp"
#include "array_element_picker.hpp"
namespace
ck
{
namespace
ck
{
template
<
typename
Arr
,
typename
Picks
>
template
<
typename
TData
,
index_t
NSize
>
__host__
__device__
constexpr
auto
p
ick_array_element
(
Arr
&
a
,
Picks
)
__host__
__device__
constexpr
auto
p
ush_back
(
const
Array
<
TData
,
NSize
>&
a
,
const
TData
&
x
)
{
{
return
ArrayElementPicker
<
Arr
,
Picks
>
(
a
);
Array
<
TData
,
NSize
+
1
>
r
;
static_for
<
0
,
NSize
,
1
>
{}([
&
r
,
&
a
](
auto
i
)
constexpr
{
r
(
i
)
=
a
[
i
];
});
r
(
Number
<
NSize
>
{})
=
x
;
return
r
;
}
}
template
<
typename
TData
,
index_t
NSize
,
index_t
...
IRs
>
template
<
typename
TData
,
index_t
NSize
,
index_t
...
IRs
>
...
@@ -63,20 +74,6 @@ __host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData
...
@@ -63,20 +74,6 @@ __host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData
return
new_array
;
return
new_array
;
}
}
template
<
typename
TData
,
index_t
NSize
,
typename
ExtractSeq
>
__host__
__device__
constexpr
auto
extract_array
(
const
Array
<
TData
,
NSize
>&
old_array
,
ExtractSeq
)
{
Array
<
TData
,
ExtractSeq
::
GetSize
()
>
new_array
;
constexpr
index_t
new_size
=
ExtractSeq
::
GetSize
();
static_assert
(
new_size
<=
NSize
,
"wrong! too many extract"
);
static_for
<
0
,
new_size
,
1
>
{}([
&
](
auto
I
)
{
new_array
(
I
)
=
old_array
[
ExtractSeq
::
At
(
I
)];
});
return
new_array
;
}
// emulate constepxr lambda for array
// emulate constepxr lambda for array
template
<
typename
F
,
typename
X
,
typename
Y
,
typename
Z
>
template
<
typename
F
,
typename
X
,
typename
Y
,
typename
Z
>
struct
lambda_array_math
struct
lambda_array_math
...
@@ -201,31 +198,25 @@ reverse_exclusive_scan_on_array(const Array<TData, NSize>& x, Reduce f, TData in
...
@@ -201,31 +198,25 @@ reverse_exclusive_scan_on_array(const Array<TData, NSize>& x, Reduce f, TData in
}
}
template
<
typename
X
,
typename
...
Ys
>
template
<
typename
X
,
typename
...
Ys
>
__host__
__device__
constexpr
auto
merge_arrays
(
const
X
&
x
,
const
Ys
&
...
ys
)
__host__
__device__
constexpr
auto
container_cat
(
const
X
&
x
,
const
Ys
&
...
ys
)
{
{
return
merge_arrays
(
x
,
merge_arrays
(
ys
...));
return
container_cat
(
x
,
container_cat
(
ys
...));
}
}
template
<
typename
T
,
index_t
NX
,
index_t
NY
>
template
<
typename
T
,
index_t
NX
,
index_t
NY
>
__host__
__device__
constexpr
auto
merge_arrays
(
const
Array
<
T
,
NX
>&
x
,
const
Array
<
T
,
NY
>&
y
)
__host__
__device__
constexpr
auto
container_cat
(
const
Array
<
T
,
NX
>&
x
,
const
Array
<
T
,
NY
>&
y
)
{
{
Array
<
T
,
NX
+
NY
>
z
;
Array
<
T
,
NX
+
NY
>
z
;
for
(
index_t
i
=
0
;
i
<
NX
;
++
i
)
static_for
<
0
,
NX
,
1
>
{}([
&
](
auto
i
)
{
z
(
i
)
=
x
[
i
];
});
{
z
(
i
)
=
x
[
i
];
}
for
(
index_t
i
=
0
;
i
<
NY
;
++
i
)
static_for
<
0
,
NY
,
1
>
{}([
&
](
auto
i
)
{
z
(
i
+
Number
<
NX
>
{})
=
y
[
i
];
});
{
z
(
i
+
NX
)
=
y
[
i
];
}
return
z
;
return
z
;
}
}
template
<
typename
X
>
template
<
typename
T
,
index_t
N
>
__host__
__device__
constexpr
auto
merge_arrays
(
const
X
&
x
)
__host__
__device__
constexpr
auto
container_cat
(
const
Array
<
T
,
N
>
&
x
)
{
{
return
x
;
return
x
;
}
}
...
...
composable_kernel/include/utility/print.hpp
View file @
4d70c71b
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
#define CK_PRINT_HPP
#define CK_PRINT_HPP
#include "array.hpp"
#include "array.hpp"
#include "statically_indexed_array.hpp"
#include "array_helper.hpp"
#include "array_helper.hpp"
#include "sequence.hpp"
#include "sequence.hpp"
...
@@ -19,7 +20,7 @@ __host__ __device__ void print_array(const char* s, T a)
...
@@ -19,7 +20,7 @@ __host__ __device__ void print_array(const char* s, T a)
static_for
<
0
,
nsize
,
1
>
{}([
&
a
](
auto
i
)
constexpr
{
printf
(
"%u, "
,
a
[
i
]);
});
static_for
<
0
,
nsize
,
1
>
{}([
&
a
](
auto
i
)
constexpr
{
printf
(
"%u, "
,
a
[
i
]);
});
printf
(
"}
\n
"
);
printf
(
"}
\n
"
);
}
}
else
if
constexpr
(
is_same
<
data_type
,
int32_t
>
{}
)
else
if
constexpr
(
true
)
{
{
printf
(
"%s size %d, {"
,
s
,
nsize
);
printf
(
"%s size %d, {"
,
s
,
nsize
);
static_for
<
0
,
nsize
,
1
>
{}([
&
a
](
auto
i
)
constexpr
{
printf
(
"%d, "
,
a
[
i
]);
});
static_for
<
0
,
nsize
,
1
>
{}([
&
a
](
auto
i
)
constexpr
{
printf
(
"%d, "
,
a
[
i
]);
});
...
@@ -39,7 +40,7 @@ __host__ __device__ void print_array_v2(const char* s, T a)
...
@@ -39,7 +40,7 @@ __host__ __device__ void print_array_v2(const char* s, T a)
static_for
<
0
,
nsize
,
1
>
{}([
&
a
](
auto
i
)
constexpr
{
printf
(
"[%u] %u, "
,
i
.
value
,
a
[
i
]);
});
static_for
<
0
,
nsize
,
1
>
{}([
&
a
](
auto
i
)
constexpr
{
printf
(
"[%u] %u, "
,
i
.
value
,
a
[
i
]);
});
printf
(
"}
\n
"
);
printf
(
"}
\n
"
);
}
}
else
if
constexpr
(
is_same
<
data_type
,
int32_t
>
{}
)
else
if
constexpr
(
true
)
{
{
printf
(
"%s size %d, {"
,
s
,
nsize
);
printf
(
"%s size %d, {"
,
s
,
nsize
);
static_for
<
0
,
nsize
,
1
>
{}([
&
a
](
auto
i
)
constexpr
{
printf
(
"[%d] %d, "
,
i
.
value
,
a
[
i
]);
});
static_for
<
0
,
nsize
,
1
>
{}([
&
a
](
auto
i
)
constexpr
{
printf
(
"[%d] %d, "
,
i
.
value
,
a
[
i
]);
});
...
...
driver/include/device_dummy_dynamic_transform_v1.hpp
View file @
4d70c71b
...
@@ -28,17 +28,17 @@ void device_dummy_dynamic_transform_v1(InDesc,
...
@@ -28,17 +28,17 @@ void device_dummy_dynamic_transform_v1(InDesc,
using
TDevice
=
typename
conditional
<
is_same
<
half_float
::
half
,
T
>::
value
,
half_t
,
T
>::
type
;
using
TDevice
=
typename
conditional
<
is_same
<
half_float
::
half
,
T
>::
value
,
half_t
,
T
>::
type
;
const
auto
in_nchw_desc
=
make_dynamic_native_tensor_descriptor
(
to_array
(
InDesc
::
GetLengths
()),
const
auto
in_nchw_desc
=
make_dynamic_native_tensor_descriptor
(
to_array
(
InDesc
::
GetStrides
()));
to_multi_index
(
InDesc
::
GetLengths
()),
to_multi_index
(
InDesc
::
GetStrides
()));
const
auto
wei_kcyx_desc
=
make_dynamic_native_tensor_descriptor
(
const
auto
wei_kcyx_desc
=
make_dynamic_native_tensor_descriptor
(
to_
array
(
WeiDesc
::
GetLengths
()),
to_
array
(
WeiDesc
::
GetStrides
()));
to_
multi_index
(
WeiDesc
::
GetLengths
()),
to_
multi_index
(
WeiDesc
::
GetStrides
()));
const
auto
out_nkhw_desc
=
make_dynamic_native_tensor_descriptor
(
const
auto
out_nkhw_desc
=
make_dynamic_native_tensor_descriptor
(
to_
array
(
OutDesc
::
GetLengths
()),
to_
array
(
OutDesc
::
GetStrides
()));
to_
multi_index
(
OutDesc
::
GetLengths
()),
to_
multi_index
(
OutDesc
::
GetStrides
()));
const
auto
conv_strides
=
to_
array
(
ConvStrides
{});
const
auto
conv_strides
=
to_
multi_index
(
ConvStrides
{});
const
auto
conv_dilations
=
to_
array
(
ConvDilations
{});
const
auto
conv_dilations
=
to_
multi_index
(
ConvDilations
{});
const
auto
in_left_pads
=
to_
array
(
InLeftPads
{});
const
auto
in_left_pads
=
to_
multi_index
(
InLeftPads
{});
const
auto
in_right_pads
=
to_
array
(
InRightPads
{});
const
auto
in_right_pads
=
to_
multi_index
(
InRightPads
{});
{
{
const
auto
tensor_descs
=
map_convolution_into_gemm_v1
(
wei_kcyx_desc
,
const
auto
tensor_descs
=
map_convolution_into_gemm_v1
(
wei_kcyx_desc
,
...
...
driver/include/device_dummy_dynamic_transform_v2.hpp
View file @
4d70c71b
...
@@ -58,10 +58,13 @@ void device_dummy_dynamic_transform_v2(InDesc,
...
@@ -58,10 +58,13 @@ void device_dummy_dynamic_transform_v2(InDesc,
const
auto
in_gemmk_gemmn_coord_step
=
make_dynamic_tensor_coordinate_step_v2
(
const
auto
in_gemmk_gemmn_coord_step
=
make_dynamic_tensor_coordinate_step_v2
(
in_gemmk_gemmn_global_desc
,
make_multi_index
(
1
,
0
));
in_gemmk_gemmn_global_desc
,
make_multi_index
(
1
,
0
));
print_array
(
"do_tansforms: "
,
in_gemmk_gemmn_coord_step
.
do_transforms_
);
for
(
index_t
iter
=
0
;
iter
<
10
;
++
iter
)
for
(
index_t
iter
=
0
;
iter
<
10
;
++
iter
)
{
{
printf
(
"iter %d
\n
"
,
iter
);
printf
(
"iter %d
\n
"
,
iter
);
print_array
(
"idx: "
,
in_gemmk_gemmn_coord
.
GetIndex
());
print_array
(
"idx: "
,
in_gemmk_gemmn_coord
.
GetIndex
());
print_array
(
"hidden idx: "
,
in_gemmk_gemmn_coord
.
GetHiddenIndex
());
printf
(
"offset: %d
\n
"
,
in_gemmk_gemmn_coord
.
GetOffset
());
printf
(
"offset: %d
\n
"
,
in_gemmk_gemmn_coord
.
GetOffset
());
printf
(
"
\n
"
);
printf
(
"
\n
"
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment