Commit b69499b9 authored by danyao12's avatar danyao12
Browse files

fix fwd dropout

parent 0d93f4a0
...@@ -436,7 +436,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm ...@@ -436,7 +436,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
elif dtype in ['fp8', 'bf8']: elif dtype in ['fp8', 'bf8']:
# no need lse/dropout kernels # no need lse/dropout kernels
for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask)) pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'no', squant, mask))
else: else:
assert False assert False
return pipelines return pipelines
......
...@@ -551,7 +551,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -551,7 +551,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
elif dtype in ['fp8', 'bf8']: elif dtype in ['fp8', 'bf8']:
# no need lse/dropout kernels # no need lse/dropout kernels
for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask)) pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'no', squant, mask))
else: else:
assert False assert False
return pipelines return pipelines
......
...@@ -51,7 +51,6 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 ...@@ -51,7 +51,6 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
// last dimension vector length used to create tensor view(and decide buffer_load vector length) // last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this // ... together with tensor distribution. tensor dist should able to overwrite this
......
...@@ -718,7 +718,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -718,7 +718,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeDropout() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeDropout()
{ {
if constexpr(Problem::kHasDropout) if constexpr(Problem::FmhaDropout::IsDropout)
{ {
constexpr auto gemm_0 = QXPolicy::template GetQKBlockGemm<Problem>(); constexpr auto gemm_0 = QXPolicy::template GetQKBlockGemm<Problem>();
constexpr auto config = constexpr auto config =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment