"...git@developer.sourcefind.cn:modelzoo/solov2-pytorch.git" did not exist on "5266dea027a9e8b230799dbfc44c389e6dfc72a8"
Commit c984208d authored by Tri Dao

Set block size to 64 x 64 for kvcache to avoid nvcc segfaults

parent 8c8b4d36
@@ -115,18 +115,12 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
 template<typename T, int Headdim>
 void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
-    auto dprops = at::cuda::getCurrentDeviceProperties();
-    bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
     constexpr int kBlockM = 64;  // Fixed for all head dimensions
-    if (!is_sm8x) {  // A100, H100
-        // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
-        // and for headdim 192 with block size 64 x 128.
-        constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 160 ? 128 : 64);
-        run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
-    } else {  // Only 99KB of smem, so we have to set kBlockN smaller for Headdim 160 and above
-        constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
-        run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
-    }
+    // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
+    // and for headdim 192 with block size 64 x 128.
+    // Also for headdim 160 with block size 64 x 128 after the rotary addition.
+    constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
+    run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
 }
 template<typename T>
...
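For reference, the new dispatch keeps kBlockM fixed at 64 and picks kBlockN at compile time from the head dimension: 256 for Headdim <= 64, 128 for Headdim <= 128, and 64 above that (the 64 x 64 kvcache case from the commit title). Below is a minimal standalone sketch, not part of the library, that encodes the same rule (the SplitKVBlockShape helper name is illustrative) and spot-checks the head dimensions called out in the comments:

// Standalone sketch of the kBlockN selection rule from this commit.
// SplitKVBlockShape is an illustrative helper, not a FlashAttention type.
#include <cstdio>

template<int Headdim>
struct SplitKVBlockShape {
    static constexpr int kBlockM = 64;  // fixed for all head dimensions
    static constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
};

// Head dimensions mentioned in the commit comments and their resulting tiles.
static_assert(SplitKVBlockShape<64>::kBlockN == 256, "headdim 64 -> 64 x 256");
static_assert(SplitKVBlockShape<96>::kBlockN == 128, "headdim 96 -> 64 x 128");
static_assert(SplitKVBlockShape<128>::kBlockN == 128, "headdim 128 -> 64 x 128");
static_assert(SplitKVBlockShape<160>::kBlockN == 64, "headdim 160 -> 64 x 64");
static_assert(SplitKVBlockShape<192>::kBlockN == 64, "headdim 192 -> 64 x 64");

int main() {
    printf("headdim 160 tile: %d x %d\n",
           SplitKVBlockShape<160>::kBlockM, SplitKVBlockShape<160>::kBlockN);
    return 0;
}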
__version__ = "2.2.3" __version__ = "2.2.3.post1"
from flash_attn.flash_attn_interface import ( from flash_attn.flash_attn_interface import (
flash_attn_func, flash_attn_func,
......
@@ -85,11 +85,11 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr
 RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0
 # Install FlashAttention
-RUN pip install flash-attn==2.2.3
+RUN pip install flash-attn==2.2.3.post1
 # Install CUDA extensions for cross-entropy, fused dense, layer norm
 RUN git clone https://github.com/HazyResearch/flash-attention \
-    && cd flash-attention && git checkout v2.2.3 \
+    && cd flash-attention && git checkout v2.2.3.post1 \
     && cd csrc/fused_softmax && pip install . && cd ../../ \
     && cd csrc/rotary && pip install . && cd ../../ \
     && cd csrc/layer_norm && pip install . && cd ../../ \
...