logger.info("Using decode attention configuration from %s for attention layer.",config_file_path)
# If a configuration has been found, return it
returnjson.load(f)
else:
logger.warning("Can not find best decode attention configuration %s for attention layer, it may not have the best performance to use default json. Please tune one. ",config_file_path)
logger.warning("Using default decode attention configuration from %s for attention layer. It may not have the best performance to use default json. ",config_file_path)
# If a configuration has been found, return it
returnjson.load(f)
else:
raiseValueError("Please surpport default config can match 16 1 576 512")
# If no optimized configuration is available, we will use the default
# configuration
returnNone
classTritonMLABackend(AttentionBackend):
classTritonMLABackend(AttentionBackend):
...
@@ -682,7 +735,7 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):
...
@@ -682,7 +735,7 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):
"encoder/decoder cross-attention "
"encoder/decoder cross-attention "
"are not implemented for "
"are not implemented for "
"TritonMLAImpl")
"TritonMLAImpl")
def_forward_prefill(
def_forward_prefill(
self,
self,
q:torch.Tensor,
q:torch.Tensor,
...
@@ -735,12 +788,14 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):
...
@@ -735,12 +788,14 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):