Commit 1b15c746 authored by wuyf1, committed by wenjh

[DCU] fix 48 FA test failures, thread overflow, and norm_mlp.


Signed-off-by: Wuyufan <Wuyf1@sugon.com>
This change resolves the following issues:
1. 48 failing FA unit tests (the FA input of shape [B, S, H, D] is reshaped to [blocknums, blocksize, H, D]); see the sketch after this list.
2. A kernel launch error (thread overflow) uncovered while fixing the FA failures.
3. The norm_mlp issue, temporarily worked around with reset_rng_states.
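
As a rough illustration of item 1, the sketch below shows how a paged KV cache stored as [num_blocks, block_size, H, D] can be viewed as 64-token pages, and how the page table is expanded to match. All sizes, tensor names, and values here are invented for illustration only; the actual logic added by this commit is `FlashAttention._remap_kv_cache_to_page64` in the diff below.

```python
import torch

# Hypothetical sizes, for illustration only.
num_blocks, old_page_size, h_kv, d = 4, 256, 2, 64
new_page_size = 64                              # page size expected by the FA kernel path
split_factor = old_page_size // new_page_size   # 4

# A KV cache stored as [blocks, block_size, H, D] is reshaped into 64-token pages.
k_cache = torch.randn(num_blocks, old_page_size, h_kv, d)
k_cache_64 = k_cache.reshape(-1, new_page_size, h_kv, d)    # [16, 64, H, D]

# Each old page index maps to split_factor consecutive new page indices;
# negative (unused) page-table entries stay at -1.
page_table = torch.tensor([[0, 2, -1]], dtype=torch.int32)
offsets = torch.arange(split_factor, dtype=torch.int32)
expanded = page_table.unsqueeze(-1) * split_factor + offsets
expanded = torch.where(page_table.unsqueeze(-1) < 0,
                       torch.full_like(expanded, -1), expanded)
expanded = expanded.reshape(page_table.shape[0], -1)

print(k_cache_64.shape)  # torch.Size([16, 64, 2, 64])
print(expanded)          # [[0, 1, 2, 3, 8, 9, 10, 11, -1, -1, -1, -1]]
```

The same reshape and page-table expansion is what the new helper performs before `block_table` is handed to the flash-attention kernel.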

See merge request dcutoolkit/deeplearing/TransformerEngine!77
Co-authored-by: Tangao <2205747538@qq.com>
Co-authored-by: wuyufffan <1095978552@qq.com>
parent 9df0c4a3
@@ -1645,6 +1645,9 @@ def test_layernorm_linear_accuracy_delay_wgrad_compute(
 @pytest.mark.parametrize("return_bias", all_boolean)
 @pytest.mark.parametrize("bias", all_boolean)
 def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization, return_bias, bias):
+    # Reset RNG state at test start to ensure deterministic model initialization
+    reset_rng_states()
     config = model_configs[model]
     te_ln_mlp = TestReturnBiasModule(
......
@@ -13,7 +13,7 @@ namespace kv_cache {
 constexpr int block_size = 1024;
 template <typename dtype>
-__global__ void reindex_kv_cache_kernel(dtype *k_cache, dtype *v_cache, int *batch_indices,
+__global__ __launch_bounds__(1024) void reindex_kv_cache_kernel(dtype *k_cache, dtype *v_cache, int *batch_indices,
                                         int *cu_new_lens, int *cu_cached_lens, int h_kv, int d_k,
                                         int d_v, int b, int max_seq_len) {
   // k_cache, v_cache: bshd
@@ -53,7 +53,7 @@ __global__ void reindex_kv_cache_kernel(dtype *k_cache, dtype *v_cache, int *bat
 }
 template <typename dtype>
-__global__ void copy_to_kv_cache_kernel(dtype *new_k, dtype *new_v, dtype *k_cache, dtype *v_cache,
+__global__ __launch_bounds__(1024) void copy_to_kv_cache_kernel(dtype *new_k, dtype *new_v, dtype *k_cache, dtype *v_cache,
                                         int *page_table, int *cu_new_lens, int *cu_cached_lens,
                                         NVTE_QKV_Format qkv_format, int h_kv, int d_k, int d_v,
                                         int b, int max_ctx_len, int max_seq_len,
......
@@ -683,6 +683,8 @@ class FlashAttention(torch.nn.Module):
     https://github.com/Dao-AILab/flash-attention
     """
+    _page64_offsets_cache: Dict[Tuple[str, int, torch.dtype], torch.Tensor] = {}
     def __init__(
         self,
         softmax_scale: float,
@@ -715,6 +717,56 @@ class FlashAttention(torch.nn.Module):
         if not self.logger.hasHandlers():
             self.logger.addHandler(attn_log._stream_handler)
+    @classmethod
+    def _get_cached_page_offsets(
+        cls, split_factor: int, device: torch.device, dtype: torch.dtype
+    ) -> torch.Tensor:
+        device_key = str(device)
+        cache_key = (device_key, split_factor, dtype)
+        offsets = cls._page64_offsets_cache.get(cache_key)
+        if offsets is None:
+            offsets = torch.arange(split_factor, dtype=dtype, device=device)
+            cls._page64_offsets_cache[cache_key] = offsets
+        return offsets
+    @classmethod
+    def _remap_kv_cache_to_page64(
+        cls,
+        key_layer: torch.Tensor,
+        value_layer: torch.Tensor,
+        block_table: torch.Tensor,
+        allow_negative_entries: bool,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        page_size = 64
+        old_page_size = key_layer.shape[1]
+        if old_page_size % page_size != 0:
+            raise ValueError(
+                f"KV cache page_size={old_page_size} must be a multiple of {page_size}"
+            )
+        split_factor = old_page_size // page_size
+        key_layer = key_layer.reshape(-1, page_size, key_layer.shape[2], key_layer.shape[3])
+        value_layer = value_layer.reshape(-1, page_size, value_layer.shape[2], value_layer.shape[3])
+        table_dtype = torch.int32
+        block_table = block_table.to(device=key_layer.device, dtype=table_dtype)
+        offsets = cls._get_cached_page_offsets(split_factor, block_table.device, table_dtype)
+        expanded_block_table = block_table.unsqueeze(-1) * split_factor + offsets.view(1, 1, -1)
+        if allow_negative_entries:
+            invalid_mask = block_table.unsqueeze(-1) < 0
+            expanded_block_table = torch.where(
+                invalid_mask,
+                torch.full_like(expanded_block_table, -1),
+                expanded_block_table,
+            )
+        expanded_block_table = expanded_block_table.reshape(
+            block_table.shape[0], block_table.shape[1] * split_factor
+        )
+        return key_layer, value_layer, expanded_block_table
     def forward(
         self,
         query_layer: torch.Tensor,
@@ -995,14 +1047,28 @@ class FlashAttention(torch.nn.Module):
             if fa_utils.v2_4_1_plus:
                 fa_optional_forward_kwargs["deterministic"] = self.deterministic
             if inference_params is not None:
-                # use block_table kwarg to support thd_2bshd for non-paged
-                fa_optional_forward_kwargs["block_table"] = (
-                    inference_params.cache_manager.page_table[:batch_size]
-                    if inference_params.is_paged
-                    else inference_params.cache_manager.batch_indices_post_step.unsqueeze(
-                        1
-                    )[:batch_size]
-                )
+                if inference_params.is_paged:
+                    page_table = inference_params.cache_manager.page_table[:batch_size]
+                    key_layer, value_layer, remapped_block_table = (
+                        self._remap_kv_cache_to_page64(
+                            key_layer,
+                            value_layer,
+                            page_table,
+                            allow_negative_entries=True,
+                        )
+                    )
+                    fa_optional_forward_kwargs["block_table"] = remapped_block_table
+                else:
+                    base = inference_params.cache_manager.batch_indices_post_step[
+                        :batch_size
+                    ].unsqueeze(1)
+                    key_layer, value_layer, remapped_block_table = self._remap_kv_cache_to_page64(
+                        key_layer,
+                        value_layer,
+                        base,
+                        allow_negative_entries=False,
+                    )
+                    fa_optional_forward_kwargs["block_table"] = remapped_block_table
             output = func(
                 query_layer,
                 key_layer,
......