gaoqiong / composable_kernel

Commit 65a00640
authored Apr 30, 2021 by Chao Liu

fix bug in tensor adaptor

parent fc148cef
Showing 6 changed files with 253 additions and 924 deletions
composable_kernel/include/tensor_description/cluster_descriptor.hpp                    +3   -1
composable_kernel/include/tensor_description/dynamic_multi_index_transform.hpp         +87  -2
composable_kernel/include/tensor_description/dynamic_multi_index_transform_helper.hpp  +7   -0
composable_kernel/include/tensor_description/tensor_adaptor.hpp                        +44  -20
composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp                       +96  -437
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp                   +16  -464
composable_kernel/include/tensor_description/cluster_descriptor.hpp

@@ -45,6 +45,7 @@ __host__ __device__ constexpr auto make_cluster_descriptor(
     return ClusterDescriptor<Lengths, decltype(order)>{};
 }

+#if 1
 template <typename Lengths,
           typename ArrangeOrder = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type>
 __host__ __device__ constexpr auto make_cluster_descriptor_v2(

@@ -64,9 +65,10 @@ __host__ __device__ constexpr auto make_cluster_descriptor_v2(
     constexpr auto up_dim_new_top_ids = Sequence<0>{};

-    return make_simple_tensor_adaptor(
+    return make_single_stage_tensor_adaptor(
         make_tuple(transform), make_tuple(low_dim_old_top_ids), make_tuple(up_dim_new_top_ids));
 }
+#endif

 } // namespace ck
 #endif
composable_kernel/include/tensor_description/dynamic_multi_index_transform.hpp

@@ -1282,7 +1282,7 @@ struct DynamicFreeze
     __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                            const UpIdx& idx_up) const
     {
-        static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
+        static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 0,
                       "wrong! inconsistent # of dimension");

         idx_low = low_idx_;
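Note: the corrected assertion reflects that a freeze transform pins one lower dimension to a stored coordinate and exposes no upper dimension, so the upper index it receives is empty. A minimal standalone sketch of that behavior, using std::array in place of ck's MultiIndex (Freeze and its members here are illustrative stand-ins, not the library types):

#include <array>
#include <cstdio>

// Illustrative stand-in for ck::DynamicFreeze: 1 lower dimension, 0 upper dimensions.
struct Freeze
{
    int low_idx_; // the frozen coordinate

    // The upper index is empty; the lower index is always the stored coordinate.
    void CalculateLowerIndex(std::array<int, 1>& idx_low,
                             const std::array<int, 0>& /* idx_up */) const
    {
        idx_low[0] = low_idx_;
    }
};

int main()
{
    Freeze freeze{7};
    std::array<int, 1> idx_low{};
    freeze.CalculateLowerIndex(idx_low, {});
    std::printf("frozen lower index = %d\n", idx_low[0]); // prints 7
    return 0;
}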
@@ -1299,7 +1299,7 @@ struct DynamicFreeze
                                                  const UpIdx& idx_up_new,
                                                  Number<Hack>)
     {
-        idx_diff_low(Number<0>{}) = index_t{Number<0>{}};
+        idx_diff_low(Number<0>{}) = 0;
     }

     __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
@@ -1328,5 +1328,90 @@ struct DynamicFreeze
     }
 };

+template <typename VectorSize, typename UpLength>
+struct DynamicVectorize
+{
+    using LowerIndex = MultiIndex<1>;
+    using UpperIndex = MultiIndex<1>;
+
+    using UpLengths = decltype(make_tuple(UpLength{}));
+
+    UpLengths up_lengths_;
+    VectorSize vector_size_;
+
+    __host__ __device__ constexpr DynamicVectorize() = default;
+
+    __host__ __device__ constexpr DynamicVectorize(const VectorSize& vector_size,
+                                                   const UpLength& up_length)
+        : vector_size_{vector_size}, up_lengths_{make_tuple(up_length)}
+    {
+    }
+
+    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
+
+    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
+
+    __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
+
+    template <typename LowIdx, typename UpIdx>
+    __host__ __device__ void CalculateLowerIndex(LowIdx& idx_low, const UpIdx& idx_up) const
+    {
+        static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
+                      "wrong! inconsistent # of dimension");
+
+        idx_low(Number<0>{}) = vector_size_ * idx_up[Number<0>{}];
+    }
+
+    template <typename LowIdxDiff,
+              typename UpIdxDiff,
+              typename LowIdx,
+              typename UpIdx,
+              index_t Hack>
+    __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
+                                              const UpIdxDiff& idx_diff_up,
+                                              LowIdx& idx_low,
+                                              const UpIdx& idx_up_new,
+                                              Number<Hack>) const
+    {
+        static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 &&
+                          LowIdx::Size() == 1 && UpIdx::Size() == 1,
+                      "wrong! inconsistent # of dimension");
+
+        constexpr auto I0 = Number<0>{};
+
+        idx_diff_low(I0) = vector_size_ * idx_diff_up[I0];
+
+        idx_low += idx_diff_low;
+    }
+
+    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
+
+    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
+    {
+        return true;
+    }
+
+    template <typename UpIdx>
+    __host__ __device__ static constexpr bool
+    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
+    {
+        return true;
+    }
+
+    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
+    {
+        return is_known_at_compile_time<UpLengths>::value;
+    }
+
+    __host__ __device__ void Print() const
+    {
+        printf("{");
+        printf("DynamicVectorize, ");
+        printf("up_lengths_");
+        print_multi_index(up_lengths_);
+        printf("}");
+    }
+};
+
 } // namespace ck
 #endif
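The DynamicVectorize transform added above maps an upper index to a lower index by scaling with a fixed vector size, and index differences are scaled the same way before being accumulated into the lower index. A minimal standalone sketch of that arithmetic with plain ints in place of ck's MultiIndex/Number types (Vectorize and its members are illustrative names, not the library API):

#include <cassert>

// Illustrative stand-in for ck::DynamicVectorize with a fixed vector size:
// lower index = vector_size * upper index, and index steps scale the same way.
struct Vectorize
{
    int vector_size_;

    int CalculateLowerIndex(int idx_up) const { return vector_size_ * idx_up; }

    // Moving the upper index by idx_diff_up moves the lower index by vector_size * idx_diff_up.
    void UpdateLowerIndex(int& idx_low, int idx_diff_up) const
    {
        idx_low += vector_size_ * idx_diff_up;
    }
};

int main()
{
    Vectorize v{4};
    int idx_low = v.CalculateLowerIndex(3); // 3rd vector starts at element 12
    assert(idx_low == 12);

    v.UpdateLowerIndex(idx_low, 2); // advance by 2 vectors -> 8 elements
    assert(idx_low == 20);
    return 0;
}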
composable_kernel/include/tensor_description/dynamic_multi_index_transform_helper.hpp

@@ -74,5 +74,12 @@ __host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_i
     return DynamicFreeze<LowerIndex>{low_idx};
 }

+template <typename VectorSize, typename UpLength>
+__host__ __device__ constexpr auto make_vectorize_transform(const VectorSize& vector_size,
+                                                             const UpLength& up_length)
+{
+    return DynamicVectorize<VectorSize, UpLength>{vector_size, up_length};
+}
+
 } // namespace ck
 #endif
composable_kernel/include/tensor_description/tensor_adaptor.hpp

@@ -235,15 +235,31 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
         constexpr index_t ndim_low =
             TensorAdaptor1{}.GetTransforms()[itran].GetNumOfLowerDimension();

+        // get the min of all lower dimenions, but not bottom dimension (because their id will
+        // be matched with top id from adaptor0)
         static_for<0, ndim_low, 1>{}([&](auto idim_low) {
-            adaptor1_min_hidden_id = math::min(
-                adaptor1_min_hidden_id,
-                TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran][idim_low].value);
+            constexpr index_t low_dim_hidden_id =
+                TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran][idim_low].value;
+
+            bool is_bottom_dim = false;
+            static_for<0, TensorAdaptor1::GetNumOfBottomDimension(), 1>{}([&](auto i) {
+                if constexpr(low_dim_hidden_id ==
+                             TensorAdaptor1::GetBottomDimensionHiddenIds()[i])
+                {
+                    is_bottom_dim = true;
+                }
+            });
+
+            if(!is_bottom_dim)
+            {
+                adaptor1_min_hidden_id = math::min(adaptor1_min_hidden_id, low_dim_hidden_id);
+            }
         });

         constexpr index_t ndim_up =
             TensorAdaptor1{}.GetTransforms()[itran].GetNumOfUpperDimension();

+        // get the min of all upper dimensions
         static_for<0, ndim_up, 1>{}([&](auto idim_up) {
             adaptor1_min_hidden_id =
                 math::min(adaptor1_min_hidden_id,
@@ -255,7 +271,7 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
     }();

     constexpr index_t adaptor1_hidden_id_shift =
-        adaptor1_min_hidden_id - adaptor0_max_hidden_id + 1;
+        adaptor0_max_hidden_id + 1 - adaptor1_min_hidden_id;

     constexpr index_t ndim_bottom_1 = TensorAdaptor1::GetNumOfBottomDimension();
@@ -276,7 +292,7 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
         // shift hidden id so every dim id is unique
         static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) {
-            low_dim_hidden_ids_1_mod(idim_low_1) -= adaptor1_hidden_id_shift;
+            low_dim_hidden_ids_1_mod(idim_low_1) += adaptor1_hidden_id_shift;
         });

         // match hidden id
@@ -322,7 +338,7 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
         // shift hidden id
         static_for<0, ndim_up_1, 1>{}([&](auto idim_up_1) {
-            up_dim_hidden_ids_1_mod(idim_up_1) -= adaptor1_hidden_id_shift;
+            up_dim_hidden_ids_1_mod(idim_up_1) += adaptor1_hidden_id_shift;
         });

         return up_dim_hidden_ids_1_mod;
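The sign changes in this and the surrounding hunks are the heart of the fix: adaptor1_hidden_id_shift is now computed as adaptor0_max_hidden_id + 1 - adaptor1_min_hidden_id and added to (rather than subtracted from) adaptor1's hidden dimension ids, so its smallest id lands just above adaptor0's largest id instead of potentially colliding with it. A standalone sketch of the arithmetic with made-up example values:

#include <cstdio>

int main()
{
    // Example values only: adaptor0 uses hidden ids 0..5, adaptor1 uses hidden ids 1..4.
    const int adaptor0_max_hidden_id = 5;
    const int adaptor1_min_hidden_id = 1;

    // New formula from this commit: shift adaptor1's ids so its smallest id
    // lands at adaptor0_max_hidden_id + 1.
    const int adaptor1_hidden_id_shift = adaptor0_max_hidden_id + 1 - adaptor1_min_hidden_id; // 5

    for(int id = 1; id <= 4; ++id) // adaptor1's original hidden ids
    {
        // ids are shifted with += (previously -=), giving 6, 7, 8, 9: unique w.r.t. adaptor0.
        std::printf("adaptor1 hidden id %d -> %d\n", id, id + adaptor1_hidden_id_shift);
    }
    return 0;
}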
@@ -344,21 +360,21 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
     // top_dim_hidden_ids = shift_hidden_id(top_dim_hidden_ids_1)
     constexpr auto top_dim_hidden_ids =
-        TensorAdaptor1::GetTopDimensionHiddenIds() - Number<adaptor1_hidden_id_shift>{};
+        TensorAdaptor1::GetTopDimensionHiddenIds() + Number<adaptor1_hidden_id_shift>{};

     // put everything together
-    return TensorAdaptor<decltype(all_transforms),
-                         decltype(all_low_dim_hidden_idss),
-                         decltype(all_up_dim_hidden_idss),
-                         decltype(bottom_dim_hidden_ids),
-                         decltype(top_dim_hidden_ids)>{all_transforms};
+    return TensorAdaptor<remove_cv_t<decltype(all_transforms)>,
+                         remove_cv_t<decltype(all_low_dim_hidden_idss)>,
+                         remove_cv_t<decltype(all_up_dim_hidden_idss)>,
+                         remove_cv_t<decltype(bottom_dim_hidden_ids)>,
+                         remove_cv_t<decltype(top_dim_hidden_ids)>>{all_transforms};
 }

 // Transforms: Tuple<transforms...>
 // LowerDimensionOldTopIdss: Tuple<Sequence<...>, ...>
 // UpperDimensionNewTopIdss: Tuple<Sequence<...>, ...>
 template <typename Transforms,
           typename LowerDimensionOldTopIdss,
           typename UpperDimensionNewTopIdss>
-__host__ __device__ constexpr auto make_simple_tensor_adaptor(const Transforms& transforms,
+__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms& transforms,
                                                               LowerDimensionOldTopIdss,
                                                               UpperDimensionNewTopIdss)
 {
@@ -400,11 +416,19 @@ __host__ __device__ constexpr auto make_simple_tensor_adaptor(const Transforms&
     constexpr auto top_dim_hidden_ids =
         typename arithmetic_sequence_gen<0, ndim_new_top, 1>::type{} + Number<ndim_old_top>{};

-    return TensorAdaptor<Transforms,
-                         decltype(low_dim_hidden_idss),
-                         decltype(up_dim_hidden_idss),
-                         decltype(bottom_dim_hidden_ids),
-                         decltype(top_dim_hidden_ids)>{transforms};
+    return TensorAdaptor<remove_cv_t<Transforms>,
+                         remove_cv_t<decltype(low_dim_hidden_idss)>,
+                         remove_cv_t<decltype(up_dim_hidden_idss)>,
+                         remove_cv_t<decltype(bottom_dim_hidden_ids)>,
+                         remove_cv_t<decltype(top_dim_hidden_ids)>>{transforms};
 }
+
+template <typename X,
+          typename... Xs,
+          typename std::enable_if<sizeof...(Xs) >= 2, bool>::type = false>
+__host__ __device__ constexpr auto chain_tensor_adaptors(const X& x, const Xs&... xs)
+{
+    return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...));
+}

 } // namespace ck
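The variadic chain_tensor_adaptors overload added at the end of this hunk reduces an N-way chain to repeated two-way chaining: chain(x, xs...) becomes chain(x, chain(xs...)). A minimal standalone sketch of the same recursion pattern, using integer addition in place of the two-adaptor chain (names here are illustrative only):

#include <cstdio>
#include <type_traits>

// Two-argument base case (stands in for the existing two-adaptor chain_tensor_adaptors).
constexpr int chain(int a, int b) { return a + b; }

// Variadic overload mirroring the commit: peel off the first argument and recurse on the rest.
template <typename X,
          typename... Xs,
          typename std::enable_if<sizeof...(Xs) >= 2, bool>::type = false>
constexpr int chain(const X& x, const Xs&... xs)
{
    return chain(x, chain(xs...));
}

int main()
{
    static_assert(chain(1, 2, 3, 4) == 10, "chains as 1 + (2 + (3 + 4))");
    std::printf("%d\n", chain(1, 2, 3, 4));
    return 0;
}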
composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp

This diff is collapsed and not shown here.
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp

This diff is collapsed and not shown here.