Commit cb2d4dbb authored by ltqin's avatar ltqin
Browse files

Merge branch 'attn-bwd-dropout' into attn-fwd-train-dropout

parents 989e3d10 0e7aeef5
...@@ -100,6 +100,17 @@ __host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, const Y& y) ...@@ -100,6 +100,17 @@ __host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, const Y& y)
return r; return r;
} }
template <typename... Xs, index_t N>
__host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, const Number<N>& y)
{
constexpr index_t NSize = sizeof...(Xs);
// Tuple<Xs...> r;
// static_for<0, NSize, 1>{}([&](auto i) { r(i) = x[i] * y; });
// return r;
return generate_tuple([&](auto i) { return x[i] * y; }, Number<NSize>{});
}
// MultiIndex = scalar * MultiIndex // MultiIndex = scalar * MultiIndex
template <typename... Xs, template <typename... Xs,
typename Y, typename Y,
......
...@@ -19,4 +19,37 @@ struct ThisThreadBlock ...@@ -19,4 +19,37 @@ struct ThisThreadBlock
__device__ static index_t GetThreadId() { return get_thread_local_1d_id(); } __device__ static index_t GetThreadId() { return get_thread_local_1d_id(); }
}; };
template <index_t ThreadPerBlock>
struct SubThreadBlock
{
static constexpr index_t kNumThread_ = ThreadPerBlock;
__device__ SubThreadBlock(int mwave, int nwave) : mwave_(mwave), nwave_(nwave) {}
__device__ static constexpr index_t GetNumOfThread() { return kNumThread_; }
template <typename TupleArg1, typename TupleArg2>
__device__ constexpr bool IsBelong(const TupleArg1& mwave_range, const TupleArg2& nwave_range)
{
// wave_range[I0] inclusive, wave_range[I1] exclusive
if(mwave_ < mwave_range[I0])
return false;
else if(mwave_ >= mwave_range[I1])
return false;
else if(nwave_ < nwave_range[I0])
return false;
else if(nwave_ >= nwave_range[I1])
return false;
else
return true;
}
__device__ static index_t GetThreadId() { return get_thread_local_1d_id(); }
private:
index_t mwave_, nwave_;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
};
} // namespace ck } // namespace ck
...@@ -58,3 +58,4 @@ add_subdirectory(batchnorm) ...@@ -58,3 +58,4 @@ add_subdirectory(batchnorm)
if(GPU_TARGETS MATCHES "gfx1100") if(GPU_TARGETS MATCHES "gfx1100")
add_subdirectory(wmma_op) add_subdirectory(wmma_op)
endif() endif()
add_subdirectory(host_tensor)
add_gtest_executable(test_host_tensor test_host_tensor.cpp)
target_link_libraries(test_host_tensor PRIVATE utility)
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment