Commit f5287e90 authored by qiyuxinlin's avatar qiyuxinlin
Browse files

fix no balance_serve import error

parent 03a65d6b
...@@ -12,7 +12,10 @@ import torch.nn as nn ...@@ -12,7 +12,10 @@ import torch.nn as nn
import transformers import transformers
from transformers import Cache, PretrainedConfig from transformers import Cache, PretrainedConfig
from typing import List, Optional, Dict, Any, Tuple from typing import List, Optional, Dict, Any, Tuple
from ktransformers.server.balance_serve.settings import sched_ext try:
from ktransformers.server.balance_serve.settings import sched_ext
except:
print("no balance_serve")
class StaticCache(transformers.StaticCache): class StaticCache(transformers.StaticCache):
""" """
Static Cache class to be used with `torch.compile(model)`. Static Cache class to be used with `torch.compile(model)`.
...@@ -210,7 +213,7 @@ class KDeepSeekV3Cache(nn.Module): ...@@ -210,7 +213,7 @@ class KDeepSeekV3Cache(nn.Module):
self.v_caches = [] self.v_caches = []
def load(self, inference_context: sched_ext.InferenceContext): def load(self, inference_context: "sched_ext.InferenceContext"):
for i in range(self.config.num_hidden_layers): for i in range(self.config.num_hidden_layers):
self.k_caches.append( self.k_caches.append(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment