Commit 15baccf2 authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'develop' into amd-develop

parents 5029a5a4 a328df25
......@@ -35,6 +35,8 @@ struct null_tile_window
CK_TILE_DEVICE constexpr auto get_window_origin() const { return BottomTensorIndex{}; }
CK_TILE_DEVICE void init_raw() {}
WindowLengths window_lengths_;
};
......
......@@ -36,6 +36,8 @@ struct tensor_view
{
}
CK_TILE_HOST_DEVICE void init_raw() { buf_.init_raw(); }
CK_TILE_HOST_DEVICE constexpr auto& get_tensor_descriptor() const { return desc_; }
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension()
......@@ -85,30 +87,34 @@ struct tensor_view
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
template <typename X,
bool oob_conditional_check = true,
bool pre_nop = false,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE void
get_vectorized_elements_raw(remove_cvref_t<X>& dst,
const TensorCoord& coord,
bool_constant<oob_conditional_check> = {}) const
CK_TILE_HOST_DEVICE void get_vectorized_elements_raw(remove_cvref_t<X>& dst,
const TensorCoord& coord,
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
return buf_.template get_raw<X, oob_conditional_check>(
return buf_.template get_raw<X, oob_conditional_check, pre_nop>(
dst,
coord.get_offset(),
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord));
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
bool_constant<pre_nop>{});
}
template <typename X,
bool pre_nop = false,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements(remove_cvref_t<DataType>* smem,
const TensorCoord& coord) const
CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements_raw(
remove_cvref_t<DataType>* smem, const TensorCoord& coord, bool_constant<pre_nop> = {}) const
{
return buf_.template async_get<X>(smem, coord.get_offset(), true /*not used*/);
return buf_.template async_get_raw<X>(
smem, coord.get_offset(), true /*not used*/, bool_constant<pre_nop>{});
}
// X is vector of DataType.
......
This diff is collapsed.
......@@ -43,7 +43,15 @@ std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim)
first = false;
else
os << delim;
os << static_cast<T>(v);
if constexpr(std::is_same_v<T, ck::f8_t> || std::is_same_v<T, ck::bf8_t>)
{
os << ck::type_convert<float>(v);
}
else
{
os << static_cast<T>(v);
}
}
return os;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment