"vscode:/vscode.git/clone" did not exist on "c8c3839cb9b0fc52885cafa431460f686c46f5ae"
Unverified Commit f74c2d1d authored by MuWinds's avatar MuWinds Committed by GitHub
Browse files

Solve `torch.backends.cuda.sdp_kernel()` is deprecated.

parent 1548c992
...@@ -13,6 +13,7 @@ from transformers import ( ...@@ -13,6 +13,7 @@ from transformers import (
from ktransformers.server.config.config import Config from ktransformers.server.config.config import Config
from ktransformers.server.schemas.base import ObjectID from ktransformers.server.schemas.base import ObjectID
from ktransformers.server.utils.multi_timer import Profiler from ktransformers.server.utils.multi_timer import Profiler
from torch.nn.attention import SDPBackend
import torch import torch
import sys, os import sys, os
from ..base import ThreadContext, BackendInterfaceBase from ..base import ThreadContext, BackendInterfaceBase
...@@ -292,7 +293,7 @@ class TransformersInterface(BackendInterfaceBase): ...@@ -292,7 +293,7 @@ class TransformersInterface(BackendInterfaceBase):
def generate(self): def generate(self):
self.profiler.set_counter("decode", 0) self.profiler.set_counter("decode", 0)
for _ in range(1, self.args.max_new_tokens): for _ in range(1, self.args.max_new_tokens):
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): with torch.nn.attention.sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION, SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
next_token = self.decode_one_tokens() next_token = self.decode_one_tokens()
self.profiler.inc("decode") self.profiler.inc("decode")
if next_token == self.tokenizer.eos_token_id: if next_token == self.tokenizer.eos_token_id:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment