Commit 2cbb8976 authored by Jing Zhang's avatar Jing Zhang
Browse files

add static_buffer_v2 zero out

parent d798c9b8
...@@ -518,6 +518,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -518,6 +518,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// main body // main body
index_t k0_block_data_begin = 0; index_t k0_block_data_begin = 0;
c_thread_buf.Clear();
if constexpr(HasMainKBlockLoop) if constexpr(HasMainKBlockLoop)
{ {
do do
......
...@@ -22,6 +22,13 @@ struct StaticBufferOfVectorTypeV2 : public StaticallyIndexedArray<T, N> ...@@ -22,6 +22,13 @@ struct StaticBufferOfVectorTypeV2 : public StaticallyIndexedArray<T, N>
static constexpr index_t vector_size = GetVectorSize(); static constexpr index_t vector_size = GetVectorSize();
__host__ __device__ static constexpr index_t GetNumVectors() { return N; }
__host__ __device__ static constexpr index_t GetNumElements()
{
return GetVectorSize() * GetNumVectors();
}
VecBaseType invalid_element_value_ = VecBaseType{0}; VecBaseType invalid_element_value_ = VecBaseType{0};
T invalid_vec_value_ = T{0}; T invalid_vec_value_ = T{0};
...@@ -91,6 +98,12 @@ struct StaticBufferOfVectorTypeV2 : public StaticallyIndexedArray<T, N> ...@@ -91,6 +98,12 @@ struct StaticBufferOfVectorTypeV2 : public StaticallyIndexedArray<T, N>
return GetElement(i, true); return GetElement(i, true);
} }
__host__ __device__ void Clear()
{
static_for<0, GetNumElements(), 1>{}(
[&](auto i) { GetElement(i, true) = invalid_element_value_; });
}
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; } __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
__host__ __device__ static constexpr bool IsDynamicBuffer() { return false; } __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment