"components/vscode:/vscode.git/clone" did not exist on "e5b6a054b51bf21658919037e8580133cbdf3fae"
Unverified Commit 363aaeef authored by Mohammad Othman's avatar Mohammad Othman Committed by GitHub
Browse files

Fix IntermediateTensors initialization and add type hints (#28743)


Signed-off-by: default avatarMohammad Othman <Mo@MohammadOthman.com>
Co-authored-by: default avatarMohammad Othman <Mo@MohammadOthman.com>
parent ac86bff8
...@@ -60,12 +60,17 @@ class IntermediateTensors: ...@@ -60,12 +60,17 @@ class IntermediateTensors:
tensors: dict[str, torch.Tensor] tensors: dict[str, torch.Tensor]
kv_connector_output: KVConnectorOutput | None kv_connector_output: KVConnectorOutput | None
def __init__(self, tensors): def __init__(
self,
tensors: dict[str, torch.Tensor],
kv_connector_output: KVConnectorOutput | None = None,
) -> None:
# manually define this function, so that # manually define this function, so that
# Dynamo knows `IntermediateTensors()` comes from this file. # Dynamo knows `IntermediateTensors()` comes from this file.
# Otherwise, dataclass will generate this function by evaluating # Otherwise, dataclass will generate this function by evaluating
# a string, and we will lose the information about the source file. # a string, and we will lose the information about the source file.
self.tensors = tensors self.tensors = tensors
self.kv_connector_output = kv_connector_output
def __getitem__(self, key: str | slice): def __getitem__(self, key: str | slice):
if isinstance(key, str): if isinstance(key, str):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment