"tests/vscode:/vscode.git/clone" did not exist on "b3f128609c991574481b02a987c0ad8f2b768369"
Commit 6a89b2f1 authored by Tri Dao

Remove constexpr in launch template to fix CI compilation

parent 97ba7a62
-Subproject commit 3a8f57a3c89cfff7aa686e95f13d9ad850f61898
+Subproject commit 34fd98056b69fbf7f0929b3f734bb5f00642e2c9
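The change itself is mechanical: every if constexpr(... >= 48 * 1024) guard around a cudaFuncSetAttribute call in the launch templates becomes a plain runtime if. The commit message only says this fixes CI compilation; one way such a guard can fail to build is that if constexpr requires its condition to be a constant expression, which a plain if does not. A minimal, hypothetical illustration (not the FlashAttention source):

// Hypothetical illustration, not the FlashAttention source: if constexpr
// requires a constant expression as its condition, while a plain if accepts
// a runtime value and the compiler can still fold the branch away when the
// value happens to be known at compile time.
void configure_smem(int smem_size) {                  // runtime value
    // if constexpr (smem_size >= 48 * 1024) { }      // error: not a constant expression
    if (smem_size >= 48 * 1024) {                     // compiles in either case
        // Opt the kernel into more than 48 KB of dynamic shared memory, e.g.
        // cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
    }
}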
@@ -64,7 +64,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, IsEvenKConst>;
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, true>;
-if constexpr(smem_size_dq_dk_dv >= 48 * 1024) {
+if (smem_size_dq_dk_dv >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
}
@@ -75,7 +75,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
});
auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
-if constexpr(Kernel_traits::kSmemdQSize >= 48 * 1024) {
+if (Kernel_traits::kSmemdQSize >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
}
@@ -103,7 +103,7 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params &params, cudaStream_t stream,
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenNConst, IsEvenKConst>;
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, false, false, IsEvenNConst, IsEvenKConst>;
-if constexpr(smem_size_dq_dk_dv >= 48 * 1024) {
+if (smem_size_dq_dk_dv >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
}
@@ -114,7 +114,7 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params &params, cudaStream_t stream,
});
auto kernel_dkv = &flash_bwd_convert_dkv_kernel<Kernel_traits>;
-if constexpr(Kernel_traits::kSmemKVSize >= 48 * 1024) {
+if (Kernel_traits::kSmemKVSize >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel_dkv, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemKVSize));
}
......
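For context on the surrounding lines: BOOL_SWITCH is the dispatch macro these launch templates use to turn a runtime flag such as is_even_K into a compile-time template parameter by instantiating the lambda body for both values. A simplified sketch of the pattern (the actual macro lives in csrc/flash_attn/src/static_switch.h and may differ in detail):

// Simplified sketch of the BOOL_SWITCH dispatch pattern: the runtime flag picks
// one of two instantiations of the lambda body, and inside each instantiation
// CONST_NAME is a constexpr bool that can be used as a template argument.
#define BOOL_SWITCH(COND, CONST_NAME, ...)        \
    [&] {                                         \
        if (COND) {                               \
            constexpr bool CONST_NAME = true;     \
            return __VA_ARGS__();                 \
        } else {                                  \
            constexpr bool CONST_NAME = false;    \
            return __VA_ARGS__();                 \
        }                                         \
    }()
// Usage, as in the hunks above:
// BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { auto kernel = &some_kernel<..., IsEvenKConst>; /* launch */ });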
@@ -46,7 +46,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
// Will only return softmax if dropout, to reduce compilation time.
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenMNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
// auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenMNConst, true, ReturnSoftmaxConst && Is_dropout>;
-if constexpr(smem_size >= 48 * 1024) {
+if (smem_size >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
@@ -74,7 +74,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, IsEvenMNConst, IsEvenKConst>;
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
-if constexpr(smem_size >= 48 * 1024) {
+if (smem_size >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
......
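The guard itself exists because CUDA limits dynamic shared memory to 48 KB per block by default; kernels that request more (as some FlashAttention configurations do) must raise the limit with cudaFuncSetAttribute before launch, exactly as in the hunks above. A self-contained sketch of that launch path with a hypothetical kernel, not the library's actual launcher:

#include <cuda_runtime.h>

// Hypothetical kernel using dynamic shared memory (illustration only).
__global__ void demo_kernel(float *out) {
    extern __shared__ float smem[];                  // size is set at launch time
    smem[threadIdx.x] = static_cast<float>(threadIdx.x);
    __syncthreads();
    out[threadIdx.x] = smem[threadIdx.x];
}

// Mirrors the guarded pattern above: raise the per-block dynamic shared-memory
// limit only when the request exceeds the 48 KB default, then launch with the
// requested size. smem_size must be at least blockDim.x * sizeof(float) here.
void launch_demo(float *out, int smem_size, cudaStream_t stream) {
    if (smem_size >= 48 * 1024) {
        // The library wraps this call in C10_CUDA_CHECK; error handling is elided here.
        cudaFuncSetAttribute(demo_kernel,
                             cudaFuncAttributeMaxDynamicSharedMemorySize,
                             smem_size);
    }
    demo_kernel<<<1, 128, smem_size, stream>>>(out);
}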
__version__ = "2.1.2.post1"
__version__ = "2.1.2.post2"
from flash_attn.flash_attn_interface import (
flash_attn_func,
......
@@ -85,11 +85,11 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr
RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0
# Install FlashAttention
-RUN pip install flash-attn==2.1.2.post1
+RUN pip install flash-attn==2.1.2.post2
# Install CUDA extensions for cross-entropy, fused dense, layer norm
RUN git clone https://github.com/HazyResearch/flash-attention \
-&& cd flash-attention && git checkout v2.1.2.post1 \
+&& cd flash-attention && git checkout v2.1.2.post2 \
&& cd csrc/fused_softmax && pip install . && cd ../../ \
&& cd csrc/rotary && pip install . && cd ../../ \
&& cd csrc/xentropy && pip install . && cd ../../ \
......