Commit 1d536d7d authored by Tri Dao's avatar Tri Dao
Browse files

Minor cleanup of softcapping

parent beb2bf2a
...@@ -106,9 +106,9 @@ void set_params_fprop(Flash_fwd_params &params, ...@@ -106,9 +106,9 @@ void set_params_fprop(Flash_fwd_params &params,
#endif #endif
if (softcap > 0.0) { if (softcap > 0.0) {
params.softcap = softmax_scale / softcap; params.softcap = softmax_scale / softcap;
params.scale_softmax = softcap; params.scale_softmax = softcap;
params.scale_softmax_log2 = softcap * M_LOG2E; params.scale_softmax_log2 = softcap * M_LOG2E;
}else{ } else{
// Remove potential NaN // Remove potential NaN
params.softcap = 0.0; params.softcap = 0.0;
params.scale_softmax = softmax_scale; params.scale_softmax = softmax_scale;
......
...@@ -24,17 +24,9 @@ using namespace cute; ...@@ -24,17 +24,9 @@ using namespace cute;
// Apply tanh softcapping elementwise: x -> tanh(x * softcap).
//
// Note: `softcap` here is the pre-folded factor computed by the caller
// (set_params_fprop stores softmax_scale / softcap into params.softcap),
// so a single multiply followed by tanh is all that is needed.
//
// The op is purely elementwise, so we iterate linearly over every element
// of the tensor regardless of its rank/shape — no static_assert on layout
// is required (the previous version restricted this to rank-3 tensors with
// a leading dimension of 4 and used three nested loops for no benefit).
template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap) {
    #pragma unroll
    for (int i = 0; i < size(tensor); ++i) {
        tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
    }
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment