"git@developer.sourcefind.cn:OpenDAS/Uni-Core.git" did not exist on "97fd9948f163c0f81390fe80a14d660409013c75"
Unverified commit b2aa2317, authored by julienmancuso, committed by GitHub
Browse files

feat: deploy planner in operator (#921)


Co-authored-by: mohammedabdulwahhab <furkhan324@berkeley.edu>
parent 57975b27
......@@ -13,9 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ["CircusController", "LocalConnector", "PlannerConnector"]
__all__ = [
"CircusController",
"LocalConnector",
"PlannerConnector",
"KubernetesConnector",
]
# Import the classes
from dynamo.planner.circusd import CircusController
from dynamo.planner.kubernetes_connector import KubernetesConnector
from dynamo.planner.local_connector import LocalConnector
from dynamo.planner.planner_connector import PlannerConnector
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from kubernetes import client, config
class KubernetesAPI:
    """Small wrapper over the Kubernetes custom-objects API for Dynamo CRDs.

    All calls target the ``nvidia.com/v1alpha1`` group/version inside the
    namespace stored in ``current_namespace``.
    """

    def __init__(self):
        # Credentials come from the in-cluster service account.
        config.load_incluster_config()
        self.custom_api = client.CustomObjectsApi()
        self.current_namespace = self._get_current_namespace()

    def _get_current_namespace(self) -> str:
        """Return the namespace from the service-account mount, or "default".

        "default" is returned when the service-account namespace file does
        not exist.
        """
        namespace_file = "/var/run/secrets/kubernetes.io/serviceaccount/namespace"
        try:
            with open(namespace_file, "r") as handle:
                contents = handle.read()
        except FileNotFoundError:
            return "default"
        return contents.strip()

    async def get_graph_deployment(
        self, component_name: str, dynamo_namespace: str
    ) -> Optional[dict]:
        """Look up the DynamoGraphDeployment that owns a component.

        The DynamoComponentDeployment is located through its component and
        dynamo-namespace labels; the graph deployment is then fetched by the
        name found in the component's owner references.

        Args:
            component_name: The name of the component
            dynamo_namespace: The dynamo namespace

        Returns:
            The DynamoGraphDeployment object, or None when the component, its
            owner reference, or the graph itself cannot be found (including a
            404 from the API).

        Raises:
            ValueError: If more than one component deployment matches.
            client.ApiException: For any API error other than 404.
        """
        selector = (
            f"nvidia.com/dynamo-component={component_name},"
            f"nvidia.com/dynamo-namespace={dynamo_namespace}"
        )
        try:
            listing = self.custom_api.list_namespaced_custom_object(
                group="nvidia.com",
                version="v1alpha1",
                namespace=self.current_namespace,
                plural="dynamocomponentdeployments",
                label_selector=selector,
            )
            matches = listing.get("items", [])
            if not matches:
                return None
            if len(matches) > 1:
                raise ValueError(
                    f"Multiple component deployments found for component {component_name} in dynamo namespace {dynamo_namespace}. "
                    "Expected exactly one deployment."
                )

            # The first owner reference pointing at a DynamoGraphDeployment
            # identifies the graph this component belongs to.
            owners = matches[0].get("metadata", {}).get("ownerReferences", [])
            graph_ref = next(
                (
                    owner
                    for owner in owners
                    if owner.get("apiVersion") == "nvidia.com/v1alpha1"
                    and owner.get("kind") == "DynamoGraphDeployment"
                ),
                None,
            )
            if not graph_ref:
                return None

            graph_name = graph_ref.get("name")
            if not graph_name:
                return None

            return self.custom_api.get_namespaced_custom_object(
                group="nvidia.com",
                version="v1alpha1",
                namespace=self.current_namespace,
                plural="dynamographdeployments",
                name=graph_name,
            )
        except client.ApiException as exc:
            # A missing object is an expected outcome; anything else is not.
            if exc.status != 404:
                raise
            return None

    async def update_graph_replicas(
        self, graph_deployment_name: str, component_name: str, replicas: int
    ) -> None:
        """Patch the replica count of one service in a DynamoGraphDeployment.

        Args:
            graph_deployment_name: Name of the DynamoGraphDeployment to patch
            component_name: Key of the service under ``spec.services``
            replicas: The replica count to set
        """
        body = {"spec": {"services": {component_name: {"replicas": replicas}}}}
        self.custom_api.patch_namespaced_custom_object(
            group="nvidia.com",
            version="v1alpha1",
            namespace=self.current_namespace,
            plural="dynamographdeployments",
            name=graph_deployment_name,
            body=body,
        )
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .kube import KubernetesAPI
from .planner_connector import PlannerConnector
class KubernetesConnector(PlannerConnector):
    """Planner connector that scales components through Kubernetes.

    Scaling is done by fetching the graph deployment that owns a component
    and patching that component's replica count up or down by one.
    """

    def __init__(self, namespace: str):
        """Create the connector.

        Args:
            namespace: The dynamo namespace used to look up components.
        """
        self.kube_api = KubernetesAPI()
        self.namespace = namespace

    async def add_component(self, component_name: str):
        """Add a component by increasing its replica count by 1.

        Raises:
            ValueError: If no graph deployment is found for the component.
        """
        deployment = await self._get_deployment_or_raise(component_name)
        current_replicas = self._get_current_replicas(deployment, component_name)
        await self.kube_api.update_graph_replicas(
            self._get_graph_deployment_name(deployment),
            component_name,
            current_replicas + 1,
        )

    async def remove_component(self, component_name: str):
        """Remove a component by decreasing its replica count by 1.

        Nothing is patched when the component is already at zero replicas.

        Raises:
            ValueError: If no graph deployment is found for the component.
        """
        deployment = await self._get_deployment_or_raise(component_name)
        current_replicas = self._get_current_replicas(deployment, component_name)
        if current_replicas > 0:
            await self.kube_api.update_graph_replicas(
                self._get_graph_deployment_name(deployment),
                component_name,
                current_replicas - 1,
            )

    async def _get_deployment_or_raise(self, component_name: str) -> dict:
        """Fetch the graph deployment owning a component, raising if absent."""
        deployment = await self.kube_api.get_graph_deployment(
            component_name, self.namespace
        )
        if deployment is None:
            # Same wording for add and remove: the missing object is the
            # graph, and component_name names the component, not the graph.
            raise ValueError(
                f"Graph not found for component {component_name} in dynamo namespace {self.namespace}"
            )
        return deployment

    def _get_current_replicas(self, deployment: dict, component_name: str) -> int:
        """Get the current replicas for a component in a graph deployment.

        Falls back to 1 when the spec, the service entry, or its ``replicas``
        key is absent.
        """
        return (
            deployment.get("spec", {})
            .get("services", {})
            .get(component_name, {})
            .get("replicas", 1)
        )

    def _get_graph_deployment_name(self, deployment: dict) -> str:
        """Get the name of the graph deployment"""
        return deployment["metadata"]["name"]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import AsyncMock, Mock
import pytest
from dynamo.planner.kubernetes_connector import KubernetesConnector
@pytest.fixture
def mock_kube_api():
    """Stand-in for a KubernetesAPI instance with awaitable API methods."""
    return Mock(
        get_graph_deployment=AsyncMock(),
        update_graph_replicas=AsyncMock(),
    )
@pytest.fixture
def mock_kube_api_class(mock_kube_api):
    """Stand-in for the KubernetesAPI class; calling it yields ``mock_kube_api``."""
    return Mock(return_value=mock_kube_api)
@pytest.fixture
def kubernetes_connector(mock_kube_api_class, monkeypatch):
    """Build a KubernetesConnector whose KubernetesAPI is replaced by a mock."""
    # Patch the KubernetesAPI class before instantiating the connector so the
    # constructor never touches real cluster configuration.
    monkeypatch.setattr(
        "dynamo.planner.kubernetes_connector.KubernetesAPI", mock_kube_api_class
    )
    # The constructor requires the dynamo namespace; pass it directly instead
    # of assigning the attribute afterwards.
    return KubernetesConnector("default")
@pytest.mark.asyncio
async def test_add_component_increases_replicas(kubernetes_connector, mock_kube_api):
    """Adding a component patches the graph with the replica count plus one."""
    # Arrange
    component_name = "test-component"
    mock_deployment = {
        "metadata": {"name": "test-graph"},
        "spec": {"services": {"test-component": {"replicas": 1}}},
    }
    mock_kube_api.get_graph_deployment.return_value = mock_deployment

    # Act
    await kubernetes_connector.add_component(component_name)

    # Assert
    # The connector looks the graph up by component name AND its dynamo
    # namespace, so both arguments must be part of the expected call.
    mock_kube_api.get_graph_deployment.assert_called_once_with(
        component_name, kubernetes_connector.namespace
    )
    mock_kube_api.update_graph_replicas.assert_called_once_with(
        "test-graph", component_name, 2
    )
@pytest.mark.asyncio
async def test_add_component_with_no_replicas_specified(
    kubernetes_connector, mock_kube_api
):
    """A service entry without ``replicas`` counts as one replica before the add."""
    name = "test-component"
    mock_kube_api.get_graph_deployment.return_value = {
        "metadata": {"name": "test-graph"},
        "spec": {"services": {name: {}}},
    }

    await kubernetes_connector.add_component(name)

    mock_kube_api.update_graph_replicas.assert_called_once_with("test-graph", name, 2)
@pytest.mark.asyncio
async def test_add_component_deployment_not_found(kubernetes_connector, mock_kube_api):
    """Adding a component with no graph deployment raises ValueError."""
    component_name = "test-component"
    expected_message = f"Graph not found for component {component_name}"
    mock_kube_api.get_graph_deployment.return_value = None

    with pytest.raises(ValueError, match=expected_message):
        await kubernetes_connector.add_component(component_name)
@pytest.mark.asyncio
async def test_remove_component_decreases_replicas(kubernetes_connector, mock_kube_api):
    """Removing a component patches the graph with the replica count minus one."""
    name = "test-component"
    mock_kube_api.get_graph_deployment.return_value = {
        "metadata": {"name": "test-graph"},
        "spec": {"services": {name: {"replicas": 2}}},
    }

    await kubernetes_connector.remove_component(name)

    mock_kube_api.update_graph_replicas.assert_called_once_with("test-graph", name, 1)
@pytest.mark.asyncio
async def test_remove_component_with_zero_replicas(kubernetes_connector, mock_kube_api):
    """Removing a component already at zero replicas issues no patch."""
    name = "test-component"
    mock_kube_api.get_graph_deployment.return_value = {
        "metadata": {"name": "test-graph"},
        "spec": {"services": {name: {"replicas": 0}}},
    }

    await kubernetes_connector.remove_component(name)

    mock_kube_api.update_graph_replicas.assert_not_called()
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
---
# Service account named planner-serviceaccount; it is the subject of the
# RoleBinding at the end of this file.
apiVersion: v1
kind: ServiceAccount
metadata:
name: planner-serviceaccount
namespace: {{ .Values.namespace }}
---
# Namespaced Role covering the two Dynamo custom resources in the nvidia.com
# API group.
# NOTE(review): the Python KubernetesAPI client in this change only lists
# dynamocomponentdeployments and gets/patches dynamographdeployments; confirm
# whether "create" and "update" are needed here.
apiVersion: rbac.authorization.k8s.io/v1
kind: Role
metadata:
name: planner-role
namespace: {{ .Values.namespace }}
rules:
- apiGroups: ["nvidia.com"]
resources: ["dynamocomponentdeployments", "dynamographdeployments"]
verbs: ["get", "list", "create", "update", "patch"]
---
# Binds planner-role to planner-serviceaccount within the same namespace.
apiVersion: rbac.authorization.k8s.io/v1
kind: RoleBinding
metadata:
name: planner-binding
namespace: {{ .Values.namespace }}
subjects:
- kind: ServiceAccount
name: planner-serviceaccount
namespace: {{ .Values.namespace }}
roleRef:
kind: Role
name: planner-role
apiGroup: rbac.authorization.k8s.io
\ No newline at end of file
......@@ -182,7 +182,6 @@ func main() {
}
if err = (&controller.DynamoComponentDeploymentReconciler{
Client: mgr.GetClient(),
Scheme: mgr.GetScheme(),
Recorder: mgr.GetEventRecorderFor("dynamocomponentdeployment"),
Config: ctrlConfig,
NatsAddr: natsAddr,
......@@ -204,7 +203,6 @@ func main() {
}
if err = (&controller.DynamoGraphDeploymentReconciler{
Client: mgr.GetClient(),
Scheme: mgr.GetScheme(),
Recorder: mgr.GetEventRecorderFor("dynamographdeployment"),
Config: ctrlConfig,
VirtualServiceGateway: istioVirtualServiceGateway,
......
......@@ -9,7 +9,6 @@ require (
emperror.dev/errors v0.8.1
github.com/apparentlymart/go-shquot v0.0.1
github.com/bsm/gomega v1.27.10
github.com/cisco-open/k8s-objectmatcher v1.9.0
github.com/huandu/xstrings v1.4.0
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/onsi/ginkgo/v2 v2.19.0
......@@ -95,7 +94,6 @@ require (
google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094 // indirect
google.golang.org/grpc v1.65.0 // indirect
google.golang.org/protobuf v1.34.2 // indirect
gopkg.in/evanphx/json-patch.v4 v4.12.0 // indirect
gopkg.in/inf.v0 v0.9.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
k8s.io/klog/v2 v2.130.1 // indirect
......
......@@ -10,8 +10,6 @@ github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cisco-open/k8s-objectmatcher v1.9.0 h1:/sfuO0BD09fpynZjXsqeZrh28Juc4VEwc2P6Ov/Q6fM=
github.com/cisco-open/k8s-objectmatcher v1.9.0/go.mod h1:CH4E6qAK+q+JwKFJn0DaTNqxrbmWCaDQzGthKLK4nZ0=
github.com/coreos/go-semver v0.3.1 h1:yi21YpKnrx1gt5R+la8n5WgS0kCrsPp33dmEyHReZr4=
github.com/coreos/go-semver v0.3.1/go.mod h1:irMmmIw/7yzSRPWryHsK7EYSg09caPQL03VsM8rvUec=
github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs=
......
......@@ -23,7 +23,6 @@ import (
"context"
"fmt"
"os"
"reflect"
"sort"
"strconv"
"strings"
......@@ -43,14 +42,12 @@ import (
commonconsts "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/consts"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/controller_common"
commonController "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/controller_common"
"github.com/cisco-open/k8s-objectmatcher/patch"
"github.com/huandu/xstrings"
istioNetworking "istio.io/api/networking/v1beta1"
networkingv1beta1 "istio.io/client-go/pkg/apis/networking/v1beta1"
k8serrors "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/api/meta"
"k8s.io/apimachinery/pkg/api/resource"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/intstr"
"k8s.io/client-go/tools/record"
......@@ -58,7 +55,6 @@ import (
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/builder"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/client/apiutil"
"sigs.k8s.io/controller-runtime/pkg/controller/controllerutil"
"sigs.k8s.io/controller-runtime/pkg/event"
"sigs.k8s.io/controller-runtime/pkg/log"
......@@ -77,12 +73,12 @@ const (
DeploymentTargetTypeDebug = "debug"
HeaderNameDebug = "X-Nvidia-Debug"
DefaultIngressSuffix = "local"
KubernetesDeploymentStrategy = "kubernetes"
)
// DynamoComponentDeploymentReconciler reconciles a DynamoComponentDeployment object
type DynamoComponentDeploymentReconciler struct {
client.Client
Scheme *runtime.Scheme
Recorder record.EventRecorder
Config controller_common.Config
NatsAddr string
......@@ -257,10 +253,12 @@ func (r *DynamoComponentDeploymentReconciler) Reconcile(ctx context.Context, req
}
// create or update api-server hpa
modified_, _, err = createOrUpdateResource(ctx, r, generateResourceOption{
dynamoComponentDeployment: dynamoComponentDeployment,
dynamoComponent: dynamoComponentCR,
}, r.generateHPA)
modified_, _, err = commonController.SyncResource(ctx, r, dynamoComponentDeployment, func(ctx context.Context) (*autoscalingv2.HorizontalPodAutoscaler, bool, error) {
return r.generateHPA(generateResourceOption{
dynamoComponentDeployment: dynamoComponentDeployment,
dynamoComponent: dynamoComponentCR,
})
})
if err != nil {
return
}
......@@ -408,7 +406,7 @@ func (r *DynamoComponentDeploymentReconciler) reconcilePVC(ctx context.Context,
return nil, err
}
pvc = constructPVC(crd, pvcConfig)
if err := controllerutil.SetControllerReference(crd, pvc, r.Scheme); err != nil {
if err := controllerutil.SetControllerReference(crd, pvc, r.Client.Scheme()); err != nil {
logger.Error(err, "Failed to set controller reference", "pvc", pvc.Name)
return nil, err
}
......@@ -458,23 +456,27 @@ func (r *DynamoComponentDeploymentReconciler) setStatusConditions(ctx context.Co
func (r *DynamoComponentDeploymentReconciler) createOrUpdateOrDeleteDeployments(ctx context.Context, opt generateResourceOption) (modified bool, depl *appsv1.Deployment, err error) {
containsStealingTrafficDebugModeEnabled := checkIfContainsStealingTrafficDebugModeEnabled(opt.dynamoComponentDeployment)
// create the main deployment
modified, depl, err = createOrUpdateResource(ctx, r, generateResourceOption{
dynamoComponentDeployment: opt.dynamoComponentDeployment,
dynamoComponent: opt.dynamoComponent,
isStealingTrafficDebugModeEnabled: false,
containsStealingTrafficDebugModeEnabled: containsStealingTrafficDebugModeEnabled,
}, r.generateDeployment)
modified, depl, err = commonController.SyncResource(ctx, r, opt.dynamoComponentDeployment, func(ctx context.Context) (*appsv1.Deployment, bool, error) {
return r.generateDeployment(ctx, generateResourceOption{
dynamoComponentDeployment: opt.dynamoComponentDeployment,
dynamoComponent: opt.dynamoComponent,
isStealingTrafficDebugModeEnabled: false,
containsStealingTrafficDebugModeEnabled: containsStealingTrafficDebugModeEnabled,
})
})
if err != nil {
err = errors.Wrap(err, "create or update deployment")
return
}
// create the debug deployment
modified2, _, err := createOrUpdateResource(ctx, r, generateResourceOption{
dynamoComponentDeployment: opt.dynamoComponentDeployment,
dynamoComponent: opt.dynamoComponent,
isStealingTrafficDebugModeEnabled: true,
containsStealingTrafficDebugModeEnabled: containsStealingTrafficDebugModeEnabled,
}, r.generateDeployment)
modified2, _, err := commonController.SyncResource(ctx, r, opt.dynamoComponentDeployment, func(ctx context.Context) (*appsv1.Deployment, bool, error) {
return r.generateDeployment(ctx, generateResourceOption{
dynamoComponentDeployment: opt.dynamoComponentDeployment,
dynamoComponent: opt.dynamoComponent,
isStealingTrafficDebugModeEnabled: true,
containsStealingTrafficDebugModeEnabled: containsStealingTrafficDebugModeEnabled,
})
})
if err != nil {
err = errors.Wrap(err, "create or update debug deployment")
}
......@@ -482,135 +484,6 @@ func (r *DynamoComponentDeploymentReconciler) createOrUpdateOrDeleteDeployments(
return
}
// createOrUpdateResource reconciles a single child resource of a
// DynamoComponentDeployment.
//
// generateResource produces the desired object and a toDelete flag. The
// function then looks up the existing object with the same name/namespace and:
//   - creates the desired object when none exists and toDelete is false;
//   - does nothing when none exists and toDelete is true;
//   - deletes the existing object when toDelete is true;
//   - updates the existing object when the calculated patch is non-empty;
//   - otherwise leaves it untouched and returns the existing object.
//
// modified reports whether a create, update or delete was performed. res is
// the created/updated object, or the existing one when no update was needed;
// it is the zero value after a delete or when nothing had to be done.
//
// NOTE(review): the recorded events format resourceNamespace where the
// resource name looks intended; the event strings are left unchanged here.
//
//nolint:nakedret
func createOrUpdateResource[T client.Object](ctx context.Context, r *DynamoComponentDeploymentReconciler, opt generateResourceOption, generateResource func(ctx context.Context, opt generateResourceOption) (T, bool, error)) (modified bool, res T, err error) {
	logs := log.FromContext(ctx)

	resource, toDelete, err := generateResource(ctx, opt)
	if err != nil {
		return
	}

	resourceNamespace := resource.GetNamespace()
	resourceName := resource.GetName()
	resourceType := reflect.TypeOf(resource).Elem().Name()
	logs = logs.WithValues("namespace", resourceNamespace, "resourceName", resourceName, "resourceType", resourceType)

	// Retrieve the GroupVersionKind (GVK) of the desired object
	gvk, err := apiutil.GVKForObject(resource, r.Client.Scheme())
	if err != nil {
		logs.Error(err, "Failed to get GVK for object")
		return
	}

	// Create a new, empty instance of the same kind to receive the live object
	obj, err := r.Client.Scheme().New(gvk)
	if err != nil {
		logs.Error(err, "Failed to create a new object for GVK")
		return
	}

	// Type assertion to ensure the object implements client.Object.
	// Report a failure instead of returning silently with a nil error.
	oldResource, ok := obj.(T)
	if !ok {
		err = fmt.Errorf("object of kind %s does not implement the expected resource type %s", gvk.Kind, resourceType)
		logs.Error(err, "Failed to convert the object created from the scheme.")
		return
	}

	err = r.Get(ctx, types.NamespacedName{Name: resourceName, Namespace: resourceNamespace}, oldResource)
	oldResourceIsNotFound := k8serrors.IsNotFound(err)
	if err != nil && !oldResourceIsNotFound {
		r.Recorder.Eventf(opt.dynamoComponentDeployment, corev1.EventTypeWarning, fmt.Sprintf("Get%s", resourceType), "Failed to get %s %s: %s", resourceType, resourceNamespace, err)
		logs.Error(err, fmt.Sprintf("Failed to get %s.", resourceType))
		return
	}
	// A not-found error is an expected outcome, not a failure.
	err = nil

	if oldResourceIsNotFound {
		if toDelete {
			logs.Info("Resource not found. Nothing to do.")
			return
		}
		logs.Info("Resource not found. Creating a new one.")

		err = errors.Wrapf(patch.DefaultAnnotator.SetLastAppliedAnnotation(resource), "set last applied annotation for resource %s", resourceName)
		if err != nil {
			logs.Error(err, "Failed to set last applied annotation.")
			r.Recorder.Eventf(opt.dynamoComponentDeployment, corev1.EventTypeWarning, "SetLastAppliedAnnotation", "Failed to set last applied annotation for %s %s: %s", resourceType, resourceNamespace, err)
			return
		}

		err = ctrl.SetControllerReference(opt.dynamoComponentDeployment, resource, r.Scheme)
		if err != nil {
			logs.Error(err, "Failed to set controller reference.")
			r.Recorder.Eventf(opt.dynamoComponentDeployment, corev1.EventTypeWarning, "SetControllerReference", "Failed to set controller reference for %s %s: %s", resourceType, resourceNamespace, err)
			return
		}

		r.Recorder.Eventf(opt.dynamoComponentDeployment, corev1.EventTypeNormal, fmt.Sprintf("Create%s", resourceType), "Creating a new %s %s", resourceType, resourceNamespace)
		err = r.Create(ctx, resource)
		if err != nil {
			logs.Error(err, "Failed to create Resource.")
			r.Recorder.Eventf(opt.dynamoComponentDeployment, corev1.EventTypeWarning, fmt.Sprintf("Create%s", resourceType), "Failed to create %s %s: %s", resourceType, resourceNamespace, err)
			return
		}
		logs.Info(fmt.Sprintf("%s created.", resourceType))
		r.Recorder.Eventf(opt.dynamoComponentDeployment, corev1.EventTypeNormal, fmt.Sprintf("Create%s", resourceType), "Created %s %s", resourceType, resourceNamespace)
		modified = true
		res = resource
	} else {
		logs.Info(fmt.Sprintf("%s found.", resourceType))
		if toDelete {
			// The resource exists but is no longer desired.
			logs.Info(fmt.Sprintf("%s is no longer desired. Deleting the existing one.", resourceType))
			err = r.Delete(ctx, oldResource)
			if err != nil {
				logs.Error(err, fmt.Sprintf("Failed to delete %s.", resourceType))
				r.Recorder.Eventf(opt.dynamoComponentDeployment, corev1.EventTypeWarning, fmt.Sprintf("Delete%s", resourceType), "Failed to delete %s %s: %s", resourceType, resourceNamespace, err)
				return
			}
			logs.Info(fmt.Sprintf("%s deleted.", resourceType))
			r.Recorder.Eventf(opt.dynamoComponentDeployment, corev1.EventTypeNormal, fmt.Sprintf("Delete%s", resourceType), "Deleted %s %s", resourceType, resourceNamespace)
			modified = true
			return
		}

		var patchResult *patch.PatchResult
		patchResult, err = patch.DefaultPatchMaker.Calculate(oldResource, resource)
		if err != nil {
			logs.Error(err, "Failed to calculate patch.")
			r.Recorder.Eventf(opt.dynamoComponentDeployment, corev1.EventTypeWarning, fmt.Sprintf("CalculatePatch%s", resourceType), "Failed to calculate patch for %s %s: %s", resourceType, resourceNamespace, err)
			return
		}

		if !patchResult.IsEmpty() {
			logs.Info(fmt.Sprintf("%s spec is different. Updating %s. The patch result is: %s", resourceType, resourceType, patchResult.String()))

			err = errors.Wrapf(patch.DefaultAnnotator.SetLastAppliedAnnotation(resource), "set last applied annotation for resource %s", resourceName)
			if err != nil {
				logs.Error(err, "Failed to set last applied annotation.")
				r.Recorder.Eventf(opt.dynamoComponentDeployment, corev1.EventTypeWarning, fmt.Sprintf("SetLastAppliedAnnotation%s", resourceType), "Failed to set last applied annotation for %s %s: %s", resourceType, resourceNamespace, err)
				return
			}

			r.Recorder.Eventf(opt.dynamoComponentDeployment, corev1.EventTypeNormal, fmt.Sprintf("Update%s", resourceType), "Updating %s %s", resourceType, resourceNamespace)
			// Carry over the live resource version so the update targets the object just read.
			resource.SetResourceVersion(oldResource.GetResourceVersion())
			err = r.Update(ctx, resource)
			if err != nil {
				logs.Error(err, fmt.Sprintf("Failed to update %s.", resourceType))
				r.Recorder.Eventf(opt.dynamoComponentDeployment, corev1.EventTypeWarning, fmt.Sprintf("Update%s", resourceType), "Failed to update %s %s: %s", resourceType, resourceNamespace, err)
				return
			}
			logs.Info(fmt.Sprintf("%s updated.", resourceType))
			r.Recorder.Eventf(opt.dynamoComponentDeployment, corev1.EventTypeNormal, fmt.Sprintf("Update%s", resourceType), "Updated %s %s", resourceType, resourceNamespace)
			modified = true
			res = resource
		} else {
			logs.Info(fmt.Sprintf("%s spec is the same. Skipping update.", resourceType))
			r.Recorder.Eventf(opt.dynamoComponentDeployment, corev1.EventTypeNormal, fmt.Sprintf("Update%s", resourceType), "Skipping update %s %s", resourceType, resourceNamespace)
			res = oldResource
		}
	}
	return
}
func getResourceAnnotations(dynamoComponentDeployment *v1alpha1.DynamoComponentDeployment) map[string]string {
resourceAnnotations := dynamoComponentDeployment.Spec.Annotations
if resourceAnnotations == nil {
......@@ -654,40 +527,46 @@ func (r *DynamoComponentDeploymentReconciler) createOrUpdateOrDeleteServices(ctx
isDebugPodReceiveProductionTrafficEnabled := checkIfIsDebugPodReceiveProductionTrafficEnabled(resourceAnnotations)
containsStealingTrafficDebugModeEnabled := checkIfContainsStealingTrafficDebugModeEnabled(opt.dynamoComponentDeployment)
// main generic service
modified, _, err = createOrUpdateResource(ctx, r, generateResourceOption{
dynamoComponentDeployment: opt.dynamoComponentDeployment,
dynamoComponent: opt.dynamoComponent,
isStealingTrafficDebugModeEnabled: false,
isDebugPodReceiveProductionTraffic: isDebugPodReceiveProductionTrafficEnabled,
containsStealingTrafficDebugModeEnabled: containsStealingTrafficDebugModeEnabled,
isGenericService: true,
}, r.generateService)
modified, _, err = commonController.SyncResource(ctx, r, opt.dynamoComponentDeployment, func(ctx context.Context) (*corev1.Service, bool, error) {
return r.generateService(ctx, generateResourceOption{
dynamoComponentDeployment: opt.dynamoComponentDeployment,
dynamoComponent: opt.dynamoComponent,
isStealingTrafficDebugModeEnabled: false,
isDebugPodReceiveProductionTraffic: isDebugPodReceiveProductionTrafficEnabled,
containsStealingTrafficDebugModeEnabled: containsStealingTrafficDebugModeEnabled,
isGenericService: true,
})
})
if err != nil {
return
}
// debug production service (if enabled)
modified_, _, err := createOrUpdateResource(ctx, r, generateResourceOption{
dynamoComponentDeployment: opt.dynamoComponentDeployment,
dynamoComponent: opt.dynamoComponent,
isStealingTrafficDebugModeEnabled: false,
isDebugPodReceiveProductionTraffic: isDebugPodReceiveProductionTrafficEnabled,
containsStealingTrafficDebugModeEnabled: containsStealingTrafficDebugModeEnabled,
isGenericService: false,
}, r.generateService)
modified_, _, err := commonController.SyncResource(ctx, r, opt.dynamoComponentDeployment, func(ctx context.Context) (*corev1.Service, bool, error) {
return r.generateService(ctx, generateResourceOption{
dynamoComponentDeployment: opt.dynamoComponentDeployment,
dynamoComponent: opt.dynamoComponent,
isStealingTrafficDebugModeEnabled: false,
isDebugPodReceiveProductionTraffic: isDebugPodReceiveProductionTrafficEnabled,
containsStealingTrafficDebugModeEnabled: containsStealingTrafficDebugModeEnabled,
isGenericService: false,
})
})
if err != nil {
return
}
modified = modified || modified_
// debug service (if enabled)
modified_, _, err = createOrUpdateResource(ctx, r, generateResourceOption{
dynamoComponentDeployment: opt.dynamoComponentDeployment,
dynamoComponent: opt.dynamoComponent,
isStealingTrafficDebugModeEnabled: true,
isDebugPodReceiveProductionTraffic: isDebugPodReceiveProductionTrafficEnabled,
containsStealingTrafficDebugModeEnabled: containsStealingTrafficDebugModeEnabled,
isGenericService: false,
}, r.generateService)
modified_, _, err = commonController.SyncResource(ctx, r, opt.dynamoComponentDeployment, func(ctx context.Context) (*corev1.Service, bool, error) {
return r.generateService(ctx, generateResourceOption{
dynamoComponentDeployment: opt.dynamoComponentDeployment,
dynamoComponent: opt.dynamoComponent,
isStealingTrafficDebugModeEnabled: true,
isDebugPodReceiveProductionTraffic: isDebugPodReceiveProductionTrafficEnabled,
containsStealingTrafficDebugModeEnabled: containsStealingTrafficDebugModeEnabled,
isGenericService: false,
})
})
if err != nil {
return
}
......@@ -696,11 +575,15 @@ func (r *DynamoComponentDeploymentReconciler) createOrUpdateOrDeleteServices(ctx
}
func (r *DynamoComponentDeploymentReconciler) createOrUpdateOrDeleteIngress(ctx context.Context, opt generateResourceOption) (modified bool, err error) {
modified, _, err = createOrUpdateResource(ctx, r, opt, r.generateIngress)
modified, _, err = commonController.SyncResource(ctx, r, opt.dynamoComponentDeployment, func(ctx context.Context) (*networkingv1.Ingress, bool, error) {
return r.generateIngress(ctx, opt)
})
if err != nil {
return
}
modified_, _, err := createOrUpdateResource(ctx, r, opt, r.generateVirtualService)
modified_, _, err := commonController.SyncResource(ctx, r, opt.dynamoComponentDeployment, func(ctx context.Context) (*networkingv1beta1.VirtualService, bool, error) {
return r.generateVirtualService(ctx, opt)
})
if err != nil {
return
}
......@@ -964,7 +847,7 @@ type generateResourceOption struct {
isGenericService bool
}
func (r *DynamoComponentDeploymentReconciler) generateHPA(ctx context.Context, opt generateResourceOption) (*autoscalingv2.HorizontalPodAutoscaler, bool, error) {
func (r *DynamoComponentDeploymentReconciler) generateHPA(opt generateResourceOption) (*autoscalingv2.HorizontalPodAutoscaler, bool, error) {
labels := r.getKubeLabels(opt.dynamoComponentDeployment, opt.dynamoComponent)
annotations := r.getKubeAnnotations(opt.dynamoComponentDeployment, opt.dynamoComponent)
......@@ -1151,6 +1034,7 @@ func (r *DynamoComponentDeploymentReconciler) generatePodTemplateSpec(ctx contex
if opt.dynamoComponentDeployment.Spec.DynamoNamespace != nil && *opt.dynamoComponentDeployment.Spec.DynamoNamespace != "" {
args = append(args, fmt.Sprintf("--%s.ServiceArgs.dynamo.namespace=%s", opt.dynamoComponentDeployment.Spec.ServiceName, *opt.dynamoComponentDeployment.Spec.DynamoNamespace))
}
args = append(args, fmt.Sprintf("--%s.environment=%s", opt.dynamoComponentDeployment.Spec.ServiceName, KubernetesDeploymentStrategy))
}
if len(opt.dynamoComponentDeployment.Spec.Envs) > 0 {
......@@ -1505,7 +1389,7 @@ func getResourcesConfig(resources *dynamoCommon.Resources) (corev1.ResourceRequi
}
//nolint:nakedret
func (r *DynamoComponentDeploymentReconciler) generateService(ctx context.Context, opt generateResourceOption) (kubeService *corev1.Service, toDelete bool, err error) {
func (r *DynamoComponentDeploymentReconciler) generateService(_ context.Context, opt generateResourceOption) (kubeService *corev1.Service, toDelete bool, err error) {
var kubeName string
if opt.isGenericService {
kubeName = r.getGenericServiceName(opt.dynamoComponentDeployment, opt.dynamoComponent)
......@@ -1602,3 +1486,7 @@ func (r *DynamoComponentDeploymentReconciler) SetupWithManager(mgr ctrl.Manager)
m.Owns(&autoscalingv2.HorizontalPodAutoscaler{})
return m.Complete(r)
}
// GetRecorder returns the event recorder held in the reconciler's Recorder field.
func (r *DynamoComponentDeploymentReconciler) GetRecorder() record.EventRecorder {
return r.Recorder
}
......@@ -19,12 +19,12 @@ package controller
import (
"context"
"encoding/json"
"fmt"
"dario.cat/mergo"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/client-go/tools/record"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/builder"
......@@ -34,6 +34,7 @@ import (
"sigs.k8s.io/controller-runtime/pkg/predicate"
nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/v1alpha1"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/consts"
commonController "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/controller_common"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/dynamo"
)
......@@ -42,6 +43,8 @@ const (
FailedState = "failed"
ReadyState = "successful"
PendingState = "pending"
DYN_DEPLOYMENT_CONFIG_ENV_VAR = "DYN_DEPLOYMENT_CONFIG"
)
type etcdStorage interface {
......@@ -51,7 +54,6 @@ type etcdStorage interface {
// DynamoGraphDeploymentReconciler reconciles a DynamoGraphDeployment object
type DynamoGraphDeploymentReconciler struct {
client.Client
Scheme *runtime.Scheme
Config commonController.Config
Recorder record.EventRecorder
VirtualServiceGateway string
......@@ -94,6 +96,7 @@ func (r *DynamoGraphDeploymentReconciler) Reconcile(ctx context.Context, req ctr
if err != nil {
dynamoDeployment.SetState(FailedState)
message = err.Error()
logger.Error(err, "Reconciliation failed")
}
// update the CRD status condition
dynamoDeployment.AddStatusCondition(metav1.Condition{
......@@ -112,6 +115,7 @@ func (r *DynamoGraphDeploymentReconciler) Reconcile(ctx context.Context, req ctr
deleted, err := commonController.HandleFinalizer(ctx, dynamoDeployment, r.Client, r)
if err != nil {
logger.Error(err, "failed to handle the finalizer")
reason = "failed_to_handle_the_finalizer"
return ctrl.Result{}, err
}
......@@ -122,6 +126,7 @@ func (r *DynamoGraphDeploymentReconciler) Reconcile(ctx context.Context, req ctr
// fetch the dynamoGraphConfig
dynamoGraphConfig, err := dynamo.GetDynamoGraphConfig(ctx, dynamoDeployment, r.Recorder)
if err != nil {
logger.Error(err, "failed to get the DynamoGraphConfig")
reason = "failed_to_get_the_DynamoGraphConfig"
return ctrl.Result{}, err
}
......@@ -129,6 +134,7 @@ func (r *DynamoGraphDeploymentReconciler) Reconcile(ctx context.Context, req ctr
// generate the dynamoComponentsDeployments from the config
dynamoComponentsDeployments, err := dynamo.GenerateDynamoComponentsDeployments(ctx, dynamoDeployment, dynamoGraphConfig, r.generateDefaultIngressSpec(dynamoDeployment))
if err != nil {
logger.Error(err, "failed to generate the DynamoComponentsDeployments")
reason = "failed_to_generate_the_DynamoComponentsDeployments"
return ctrl.Result{}, err
}
......@@ -138,6 +144,7 @@ func (r *DynamoGraphDeploymentReconciler) Reconcile(ctx context.Context, req ctr
if _, ok := dynamoDeployment.Spec.Services[serviceName]; ok {
err := mergo.Merge(&deployment.Spec.DynamoComponentDeploymentSharedSpec, dynamoDeployment.Spec.Services[serviceName].DynamoComponentDeploymentSharedSpec, mergo.WithOverride)
if err != nil {
logger.Error(err, "failed to merge the DynamoComponentsDeployments")
reason = "failed_to_merge_the_DynamoComponentsDeployments"
return ctrl.Result{}, err
}
......@@ -152,6 +159,11 @@ func (r *DynamoGraphDeploymentReconciler) Reconcile(ctx context.Context, req ctr
if len(dynamoDeployment.Spec.Envs) > 0 {
deployment.Spec.Envs = mergeEnvs(dynamoDeployment.Spec.Envs, deployment.Spec.Envs)
}
err := updateDynDeploymentConfig(deployment, consts.DynamoServicePort)
if err != nil {
logger.Error(err, fmt.Sprintf("Failed to update the %v env var", DYN_DEPLOYMENT_CONFIG_ENV_VAR))
return ctrl.Result{}, err
}
}
// reconcile the dynamoComponent
......@@ -165,12 +177,11 @@ func (r *DynamoGraphDeploymentReconciler) Reconcile(ctx context.Context, req ctr
DynamoComponent: dynamoDeployment.Spec.DynamoGraph,
},
}
if err := ctrl.SetControllerReference(dynamoDeployment, dynamoComponent, r.Scheme); err != nil {
reason = "failed_to_set_the_controller_reference_for_the_DynamoComponent"
return ctrl.Result{}, err
}
dynamoComponent, err = commonController.SyncResource(ctx, r.Client, dynamoComponent, false)
_, dynamoComponent, err = commonController.SyncResource(ctx, r, dynamoDeployment, func(ctx context.Context) (*nvidiacomv1alpha1.DynamoComponent, bool, error) {
return dynamoComponent, false, nil
})
if err != nil {
logger.Error(err, "failed to sync the DynamoComponent")
reason = "failed_to_sync_the_DynamoComponent"
return ctrl.Result{}, err
}
......@@ -186,12 +197,11 @@ func (r *DynamoGraphDeploymentReconciler) Reconcile(ctx context.Context, req ctr
// reconcile the dynamoComponentsDeployments
for serviceName, dynamoComponentDeployment := range dynamoComponentsDeployments {
logger.Info("Reconciling the DynamoComponentDeployment", "serviceName", serviceName, "dynamoComponentDeployment", dynamoComponentDeployment)
if err := ctrl.SetControllerReference(dynamoDeployment, dynamoComponentDeployment, r.Scheme); err != nil {
reason = "failed_to_set_the_controller_reference_for_the_DynamoComponentDeployment"
return ctrl.Result{}, err
}
dynamoComponentDeployment, err = commonController.SyncResource(ctx, r.Client, dynamoComponentDeployment, false)
_, dynamoComponentDeployment, err = commonController.SyncResource(ctx, r, dynamoDeployment, func(ctx context.Context) (*nvidiacomv1alpha1.DynamoComponentDeployment, bool, error) {
return dynamoComponentDeployment, false, nil
})
if err != nil {
logger.Error(err, "failed to sync the DynamoComponentDeployment")
reason = "failed_to_sync_the_DynamoComponentDeployment"
return ctrl.Result{}, err
}
......@@ -265,6 +275,39 @@ func mergeEnvs(common, specific []corev1.EnvVar) []corev1.EnvVar {
return merged
}
// updateDynDeploymentConfig sets the "port" entry of the component's own service
// inside the DYN_DEPLOYMENT_CONFIG env var of dynamoDeploymentComponent.
//
// The function is a no-op (and returns nil) when:
//   - the component is not the main component (IsMainComponent() is false),
//   - the component has no DYN_DEPLOYMENT_CONFIG env var,
//   - the config has no JSON-object entry keyed by Spec.ServiceName, or
//   - that entry has no "port" key.
//
// Only the first env var named DYN_DEPLOYMENT_CONFIG is considered. An error is
// returned when its value is not a JSON object or when re-encoding fails.
//
// The config is edited through json.RawMessage so that every value other than
// the port is carried over without being decoded. Decoding into map[string]any
// would turn all JSON numbers into float64 and silently alter large integers.
// The env var is rewritten only when a port was actually replaced.
func updateDynDeploymentConfig(dynamoDeploymentComponent *nvidiacomv1alpha1.DynamoComponentDeployment, newPort int) error {
	if !dynamoDeploymentComponent.IsMainComponent() {
		return nil
	}

	serviceName := dynamoDeploymentComponent.Spec.ServiceName
	envs := dynamoDeploymentComponent.Spec.Envs
	for i := range envs {
		if envs[i].Name != DYN_DEPLOYMENT_CONFIG_ENV_VAR {
			continue
		}

		var config map[string]json.RawMessage
		if err := json.Unmarshal([]byte(envs[i].Value), &config); err != nil {
			return fmt.Errorf("failed to unmarshal %v: %w", DYN_DEPLOYMENT_CONFIG_ENV_VAR, err)
		}

		rawServiceConfig, ok := config[serviceName]
		if !ok {
			return nil
		}

		// The outer document already parsed, so a failure here only means the
		// service entry is not a JSON object; there is no port to update then.
		var serviceConfig map[string]json.RawMessage
		if err := json.Unmarshal(rawServiceConfig, &serviceConfig); err != nil {
			return nil
		}
		if _, portExists := serviceConfig["port"]; !portExists {
			return nil
		}

		port, err := json.Marshal(newPort)
		if err != nil {
			return fmt.Errorf("failed to marshal updated config: %w", err)
		}
		serviceConfig["port"] = port

		updatedService, err := json.Marshal(serviceConfig)
		if err != nil {
			return fmt.Errorf("failed to marshal updated config: %w", err)
		}
		config[serviceName] = updatedService

		updated, err := json.Marshal(config)
		if err != nil {
			return fmt.Errorf("failed to marshal updated config: %w", err)
		}

		// Write through the slice index so the component's env var is modified in place.
		envs[i].Value = string(updated)
		return nil
	}
	return nil
}
func (r *DynamoGraphDeploymentReconciler) FinalizeResource(ctx context.Context, dynamoDeployment *nvidiacomv1alpha1.DynamoGraphDeployment) error {
// for now doing nothing
return nil
......@@ -294,3 +337,7 @@ func (r *DynamoGraphDeploymentReconciler) SetupWithManager(mgr ctrl.Manager) err
WithEventFilter(commonController.EphemeralDeploymentEventFilter(r.Config)).
Complete(r)
}
// GetRecorder returns the event recorder stored in the reconciler's Recorder field.
// Reconcile passes this reconciler to commonController.SyncResource, which
// obtains the recorder through this method.
func (r *DynamoGraphDeploymentReconciler) GetRecorder() record.EventRecorder {
return r.Recorder
}
......@@ -22,6 +22,8 @@ import (
"sort"
"testing"
nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/v1alpha1"
"github.com/bsm/gomega"
corev1 "k8s.io/api/core/v1"
)
......@@ -80,3 +82,125 @@ func Test_mergeEnvs(t *testing.T) {
})
}
}
func Test_updateDynDeploymentConfig(t *testing.T) {
type args struct {
dynamoDeploymentComponent *nvidiacomv1alpha1.DynamoComponentDeployment
newPort int
}
tests := []struct {
name string
args args
want []corev1.EnvVar
wantErr bool
}{
{
name: "main component",
args: args{
dynamoDeploymentComponent: &nvidiacomv1alpha1.DynamoComponentDeployment{
Spec: nvidiacomv1alpha1.DynamoComponentDeploymentSpec{
DynamoTag: "graphs.agg:Frontend",
DynamoComponentDeploymentSharedSpec: nvidiacomv1alpha1.DynamoComponentDeploymentSharedSpec{
ServiceName: "Frontend",
Envs: []corev1.EnvVar{
{
Name: "DYN_DEPLOYMENT_CONFIG",
Value: `{"Frontend":{"port":8080},"Planner":{"environment":"kubernetes"}}`,
},
{
Name: "OTHER",
Value: `value`,
},
},
},
},
},
newPort: 3000,
},
want: []corev1.EnvVar{
{
Name: "DYN_DEPLOYMENT_CONFIG",
Value: `{"Frontend":{"port":3000},"Planner":{"environment":"kubernetes"}}`,
},
{
Name: "OTHER",
Value: `value`,
},
},
wantErr: false,
},
{
name: "not main component",
args: args{
dynamoDeploymentComponent: &nvidiacomv1alpha1.DynamoComponentDeployment{
Spec: nvidiacomv1alpha1.DynamoComponentDeploymentSpec{
DynamoTag: "graphs.agg:Frontend",
DynamoComponentDeploymentSharedSpec: nvidiacomv1alpha1.DynamoComponentDeploymentSharedSpec{
ServiceName: "Other",
Envs: []corev1.EnvVar{
{
Name: "DYN_DEPLOYMENT_CONFIG",
Value: `{"Frontend":{"port":8080},"Planner":{"environment":"kubernetes"}}`,
},
{
Name: "OTHER",
Value: `value`,
},
},
},
},
},
newPort: 3000,
},
want: []corev1.EnvVar{
{
Name: "DYN_DEPLOYMENT_CONFIG",
Value: `{"Frontend":{"port":8080},"Planner":{"environment":"kubernetes"}}`,
},
{
Name: "OTHER",
Value: `value`,
},
},
wantErr: false,
},
{
name: "no DYN_DEPLOYMENT_CONFIG env variable",
args: args{
dynamoDeploymentComponent: &nvidiacomv1alpha1.DynamoComponentDeployment{
Spec: nvidiacomv1alpha1.DynamoComponentDeploymentSpec{
DynamoTag: "graphs.agg:Frontend",
DynamoComponentDeploymentSharedSpec: nvidiacomv1alpha1.DynamoComponentDeploymentSharedSpec{
ServiceName: "Frontend",
Envs: []corev1.EnvVar{
{
Name: "OTHER",
Value: `value`,
},
},
},
},
},
newPort: 8080,
},
want: []corev1.EnvVar{
{
Name: "OTHER",
Value: `value`,
},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := updateDynDeploymentConfig(tt.args.dynamoDeploymentComponent, tt.args.newPort)
if (err != nil) != tt.wantErr {
t.Errorf("updateDynDeploymentConfig() error = %v, wantErr %v", err, tt.wantErr)
return
}
g := gomega.NewGomegaWithT(t)
g.Expect(tt.args.dynamoDeploymentComponent.Spec.Envs).To(gomega.Equal(tt.want))
})
}
}
......@@ -22,12 +22,19 @@ import (
"crypto/sha256"
"encoding/json"
"fmt"
"reflect"
"sort"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/types"
"k8s.io/client-go/tools/record"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/client/apiutil"
"sigs.k8s.io/controller-runtime/pkg/log"
)
const (
......@@ -35,76 +42,241 @@ const (
NvidiaAnnotationHashKey = "nvidia.com/last-applied-hash"
)
type Resource interface {
client.Object
GetSpec() any
SetSpec(spec any)
type Reconciler interface {
client.Client
GetRecorder() record.EventRecorder
}
func SyncResource[T Resource](ctx context.Context, c client.Client, desired T, createOnly bool) (T, error) {
// ResourceGenerator is a function that generates a resource.
// it must return the resource, a boolean indicating if the resource should be deleted, and an error
// if the resource should be deleted, the returned resource must contain the necessary information to delete it (name and namespace)
type ResourceGenerator[T client.Object] func(ctx context.Context) (T, bool, error)
//nolint:nakedret
func SyncResource[T client.Object](ctx context.Context, r Reconciler, parentResource client.Object, generateResource ResourceGenerator[T]) (modified bool, res T, err error) {
logs := log.FromContext(ctx)
resource, toDelete, err := generateResource(ctx)
if err != nil {
return
}
resourceNamespace := resource.GetNamespace()
resourceName := resource.GetName()
resourceType := reflect.TypeOf(resource).Elem().Name()
logs = logs.WithValues("namespace", resourceNamespace, "resourceName", resourceName, "resourceType", resourceType)
// Retrieve the GroupVersionKind (GVK) of the desired object
gvk, err := apiutil.GVKForObject(desired, c.Scheme())
gvk, err := apiutil.GVKForObject(resource, r.Scheme())
if err != nil {
return desired, fmt.Errorf("failed to get GVK for object: %w", err)
logs.Error(err, "Failed to get GVK for object")
return
}
// Create a new instance of the object
obj, err := c.Scheme().New(gvk)
obj, err := r.Scheme().New(gvk)
if err != nil {
return desired, fmt.Errorf("failed to create a new object for GVK %s: %w", gvk, err)
logs.Error(err, "Failed to create a new object for GVK")
return
}
// Type assertion to ensure the object implements client.Object
current, ok := obj.(T)
oldResource, ok := obj.(T)
if !ok {
return desired, fmt.Errorf("failed to cast object to the expected type %T", desired)
return
}
namespacedName := types.NamespacedName{
Name: desired.GetName(),
Namespace: desired.GetNamespace(),
err = r.Get(ctx, types.NamespacedName{Name: resourceName, Namespace: resourceNamespace}, oldResource)
oldResourceIsNotFound := errors.IsNotFound(err)
if err != nil && !oldResourceIsNotFound {
r.GetRecorder().Eventf(parentResource, corev1.EventTypeWarning, fmt.Sprintf("Get%s", resourceType), "Failed to get %s %s: %s", resourceType, resourceNamespace, err)
logs.Error(err, "Failed to get resource.")
return
}
err = nil
// Retrieve the existing resource
err = c.Get(ctx, namespacedName, current)
if err != nil {
if errors.IsNotFound(err) {
// If the resource doesn't exist, create it
if err := c.Create(ctx, desired); err != nil {
return desired, fmt.Errorf("failed to create resource: %w", err)
if oldResourceIsNotFound {
if toDelete {
logs.Info("Resource not found. Nothing to do.")
return
}
logs.Info("Resource not found. Creating a new one.")
err = ctrl.SetControllerReference(parentResource, resource, r.Scheme())
if err != nil {
logs.Error(err, "Failed to set controller reference.")
r.GetRecorder().Eventf(parentResource, corev1.EventTypeWarning, "SetControllerReference", "Failed to set controller reference for %s %s: %s", resourceType, resourceNamespace, err)
return
}
var hash string
hash, err = GetSpecHash(resource)
if err != nil {
logs.Error(err, "Failed to get spec hash.")
r.GetRecorder().Eventf(parentResource, corev1.EventTypeWarning, "GetSpecHash", "Failed to get spec hash for %s %s: %s", resourceType, resourceNamespace, err)
return
}
updateHashAnnotation(resource, hash)
r.GetRecorder().Eventf(parentResource, corev1.EventTypeNormal, fmt.Sprintf("Create%s", resourceType), "Creating a new %s %s", resourceType, resourceNamespace)
err = r.Create(ctx, resource)
if err != nil {
logs.Error(err, "Failed to create Resource.")
r.GetRecorder().Eventf(parentResource, corev1.EventTypeWarning, fmt.Sprintf("Create%s", resourceType), "Failed to create %s %s: %s", resourceType, resourceNamespace, err)
return
}
logs.Info(fmt.Sprintf("%s created.", resourceType))
r.GetRecorder().Eventf(parentResource, corev1.EventTypeNormal, fmt.Sprintf("Create%s", resourceType), "Created %s %s", resourceType, resourceNamespace)
modified = true
res = resource
} else {
logs.Info(fmt.Sprintf("%s found.", resourceType))
if toDelete {
logs.Info(fmt.Sprintf("%s not found. Deleting the existing one.", resourceType))
err = r.Delete(ctx, oldResource)
if err != nil {
logs.Error(err, fmt.Sprintf("Failed to delete %s.", resourceType))
r.GetRecorder().Eventf(parentResource, corev1.EventTypeWarning, fmt.Sprintf("Delete%s", resourceType), "Failed to delete %s %s: %s", resourceType, resourceNamespace, err)
return
}
return desired, nil
logs.Info(fmt.Sprintf("%s deleted.", resourceType))
r.GetRecorder().Eventf(parentResource, corev1.EventTypeNormal, fmt.Sprintf("Delete%s", resourceType), "Deleted %s %s", resourceType, resourceNamespace)
modified = true
return
}
// Check if the Spec has changed and update if necessary
var newHash *string
newHash, err = IsSpecChanged(oldResource, resource)
if err != nil {
r.GetRecorder().Eventf(parentResource, corev1.EventTypeWarning, fmt.Sprintf("CalculatePatch%s", resourceType), "Failed to calculate patch for %s %s: %s", resourceType, resourceNamespace, err)
return false, resource, fmt.Errorf("failed to check if spec has changed: %w", err)
}
return desired, fmt.Errorf("failed to get resource: %w", err)
if newHash != nil {
// update the spec of the current object with the desired spec
err = CopySpec(resource, oldResource)
if err != nil {
logs.Error(err, fmt.Sprintf("Failed to copy spec for %s.", resourceType))
r.GetRecorder().Eventf(parentResource, corev1.EventTypeWarning, fmt.Sprintf("CopySpec%s", resourceType), "Failed to copy spec for %s %s: %s", resourceType, resourceNamespace, err)
return
}
updateHashAnnotation(oldResource, *newHash)
err = r.Update(ctx, oldResource)
if err != nil {
logs.Error(err, fmt.Sprintf("Failed to update %s.", resourceType))
r.GetRecorder().Eventf(parentResource, corev1.EventTypeWarning, fmt.Sprintf("Update%s", resourceType), "Failed to update %s %s: %s", resourceType, resourceNamespace, err)
return
}
logs.Info(fmt.Sprintf("%s updated.", resourceType))
r.GetRecorder().Eventf(parentResource, corev1.EventTypeNormal, fmt.Sprintf("Update%s", resourceType), "Updated %s %s", resourceType, resourceNamespace)
modified = true
res = oldResource
} else {
logs.Info(fmt.Sprintf("%s spec is the same. Skipping update.", resourceType))
r.GetRecorder().Eventf(parentResource, corev1.EventTypeNormal, fmt.Sprintf("Update%s", resourceType), "Skipping update %s %s", resourceType, resourceNamespace)
res = oldResource
}
}
return
}
// CopySpec copies only the Spec field from source to destination using Unstructured
func CopySpec(source, destination client.Object) error {
// Convert source to unstructured
sourceMap, err := runtime.DefaultUnstructuredConverter.ToUnstructured(source)
if err != nil {
return err
}
sourceUnstructured := &unstructured.Unstructured{Object: sourceMap}
// Convert destination to unstructured
destMap, err := runtime.DefaultUnstructuredConverter.ToUnstructured(destination)
if err != nil {
return err
}
destUnstructured := &unstructured.Unstructured{Object: destMap}
// Extract only the spec from source
sourceSpec, found, err := unstructured.NestedFieldCopy(sourceUnstructured.Object, "spec")
if err != nil {
return err
}
if !found {
return fmt.Errorf("spec not found in source object")
}
if createOnly {
return current, nil
// Set the spec in the destination
err = unstructured.SetNestedField(destUnstructured.Object, sourceSpec, "spec")
if err != nil {
return err
}
// Check if the Spec has changed and update if necessary
if IsSpecChanged(current, desired) {
// update the spec of the current object with the desired spec
desired.SetResourceVersion(current.GetResourceVersion())
if err := c.Update(ctx, desired); err != nil {
return desired, fmt.Errorf("failed to update resource: %w", err)
// Convert back to the original object
return runtime.DefaultUnstructuredConverter.FromUnstructured(destUnstructured.Object, destination)
}
// getSpec returns a deep copy of the top-level "spec" field of obj, read from
// the object's unstructured (map) representation.
// It returns (nil, nil) when the object has no "spec" field, and an error when
// the conversion to unstructured or the field lookup fails.
func getSpec(obj client.Object) (any, error) {
	content, err := runtime.DefaultUnstructuredConverter.ToUnstructured(obj)
	if err != nil {
		return nil, err
	}

	spec, found, err := unstructured.NestedFieldCopy(content, "spec")
	switch {
	case err != nil:
		return nil, err
	case !found:
		return nil, nil
	}
	return spec, nil
}
// IsSpecChanged compares the spec hash of desired with the hash recorded on
// current under the NvidiaAnnotationHashKey annotation.
// It returns nil when the recorded hash equals the hash of desired's spec.
// Otherwise (different hash, or no annotation on current) it returns a pointer
// to the new hash. An error is returned when the hash of desired cannot be computed.
func IsSpecChanged(current client.Object, desired client.Object) (*string, error) {
	desiredHash, err := GetSpecHash(desired)
	if err != nil {
		return nil, err
	}

	recordedHash, recorded := current.GetAnnotations()[NvidiaAnnotationHashKey]
	if recorded && recordedHash == desiredHash {
		return nil, nil
	}
	return &desiredHash, nil
}
// Return the updated object
return current, nil
// GetSpecHash extracts the spec of obj with getSpec and returns the result of
// GetResourceHash for it. An extraction error is returned with an empty hash.
func GetSpecHash(obj client.Object) (string, error) {
	spec, specErr := getSpec(obj)
	if specErr == nil {
		return GetResourceHash(spec)
	}
	return "", specErr
}
// updateHashAnnotation stores hash on obj under the NvidiaAnnotationHashKey
// annotation, creating the annotation map when the object has none.
func updateHashAnnotation(obj client.Object, hash string) {
	existing := obj.GetAnnotations()
	if existing == nil {
		obj.SetAnnotations(map[string]string{NvidiaAnnotationHashKey: hash})
		return
	}
	existing[NvidiaAnnotationHashKey] = hash
	obj.SetAnnotations(existing)
}
// GetResourceHash returns a consistent hash for the given object spec
func GetResourceHash(obj any) string {
func GetResourceHash(obj any) (string, error) {
// Convert obj to a map[string]interface{}
objMap, err := json.Marshal(obj)
if err != nil {
panic(err)
return "", err
}
var objData map[string]interface{}
if err := json.Unmarshal(objMap, &objData); err != nil {
panic(err)
return "", err
}
// Sort keys to ensure consistent serialization
......@@ -113,56 +285,13 @@ func GetResourceHash(obj any) string {
// Serialize to JSON
serialized, err := json.Marshal(sortedObjData)
if err != nil {
panic(err)
return "", err
}
// Compute the hash
hasher := sha256.New()
hasher.Write(serialized)
return fmt.Sprintf("%x", hasher.Sum(nil))
}
// IsSpecChanged returns true if the spec has changed between the existing one
// and the new resource spec compared by hash.
func IsSpecChanged(current Resource, desired Resource) bool {
if current == nil && desired != nil {
return true
}
hashStr := GetResourceHash(desired.GetSpec())
foundHashAnnotation := false
currentAnnotations := current.GetAnnotations()
desiredAnnotations := desired.GetAnnotations()
if currentAnnotations == nil {
currentAnnotations = map[string]string{}
}
if desiredAnnotations == nil {
desiredAnnotations = map[string]string{}
}
for annotation, value := range currentAnnotations {
if annotation == NvidiaAnnotationHashKey {
if value != hashStr {
// Update annotation to be added to resource as per new spec and indicate spec update is required
desiredAnnotations[NvidiaAnnotationHashKey] = hashStr
desired.SetAnnotations(desiredAnnotations)
return true
}
foundHashAnnotation = true
break
}
}
if !foundHashAnnotation {
// Update annotation to be added to resource as per new spec and indicate spec update is required
desiredAnnotations[NvidiaAnnotationHashKey] = hashStr
desired.SetAnnotations(desiredAnnotations)
return true
}
return false
return fmt.Sprintf("%x", hasher.Sum(nil)), nil
}
// SortKeys recursively sorts the keys of a map to ensure consistent serialization
......
......@@ -27,7 +27,7 @@ import (
"emperror.dev/errors"
apiStoreClient "github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/dynamo/api_store_client"
compounaiCommon "github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/dynamo/common"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/dynamo/common"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/dynamo/schemas"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/v1alpha1"
commonconfig "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/config"
......@@ -42,11 +42,17 @@ import (
"gopkg.in/yaml.v2"
)
const (
// ComponentTypePlanner is the DynamoConfig.ComponentType value that
// GenerateDynamoComponentsDeployments checks to recognise a planner service.
ComponentTypePlanner = "planner"
// PlannerServiceAccountName is the service account name set in the
// ExtraPodSpec of a deployment generated for a planner service.
PlannerServiceAccountName = "planner-serviceaccount"
)
// ServiceConfig represents the YAML configuration structure for a service
type DynamoConfig struct {
Enabled bool `yaml:"enabled"`
Namespace string `yaml:"namespace"`
Name string `yaml:"name"`
Enabled bool `yaml:"enabled"`
Namespace string `yaml:"namespace"`
Name string `yaml:"name"`
ComponentType string `yaml:"component_type,omitempty"`
}
type Resources struct {
......@@ -230,6 +236,7 @@ func GetDynamoGraphConfig(ctx context.Context, dynamoDeployment *v1alpha1.Dynamo
func GenerateDynamoComponentsDeployments(ctx context.Context, parentDynamoGraphDeployment *v1alpha1.DynamoGraphDeployment, config *DynamoGraphConfig, ingressSpec *v1alpha1.IngressSpec) (map[string]*v1alpha1.DynamoComponentDeployment, error) {
dynamoServices := make(map[string]string)
deployments := make(map[string]*v1alpha1.DynamoComponentDeployment)
graphDynamoNamespace := ""
for _, service := range config.Services {
deployment := &v1alpha1.DynamoComponentDeployment{}
deployment.Name = fmt.Sprintf("%s-%s", parentDynamoGraphDeployment.Name, strings.ToLower(service.Name))
......@@ -252,6 +259,18 @@ func GenerateDynamoComponentsDeployments(ctx context.Context, parentDynamoGraphD
deployment.Spec.DynamoNamespace = &dynamoNamespace
dynamoServices[service.Name] = fmt.Sprintf("%s/%s", service.Config.Dynamo.Name, dynamoNamespace)
labels[commonconsts.KubeLabelDynamoNamespace] = dynamoNamespace
// we check that all dynamo components are in the same namespace
// this is needed for the planner to work correctly
// this check will be removed when the global planner will be implemented
if graphDynamoNamespace != "" && graphDynamoNamespace != dynamoNamespace {
return nil, fmt.Errorf("different namespaces for the same graph, expected %s, got %s", graphDynamoNamespace, dynamoNamespace)
}
graphDynamoNamespace = dynamoNamespace
if service.Config.Dynamo.ComponentType == ComponentTypePlanner {
deployment.Spec.ExtraPodSpec = &common.ExtraPodSpec{
ServiceAccountName: PlannerServiceAccountName,
}
}
}
// Check http_exposed independently
if config.EntryService == service.Name && service.Config.HttpExposed {
......@@ -260,14 +279,14 @@ func GenerateDynamoComponentsDeployments(ctx context.Context, parentDynamoGraphD
}
if service.Config.Resources != nil {
deployment.Spec.Resources = &compounaiCommon.Resources{
Requests: &compounaiCommon.ResourceItem{
deployment.Spec.Resources = &common.Resources{
Requests: &common.ResourceItem{
CPU: service.Config.Resources.CPU,
Memory: service.Config.Resources.Memory,
GPU: service.Config.Resources.GPU,
Custom: service.Config.Resources.Custom,
},
Limits: &compounaiCommon.ResourceItem{
Limits: &common.ResourceItem{
CPU: service.Config.Resources.CPU,
Memory: service.Config.Resources.Memory,
GPU: service.Config.Resources.GPU,
......
......@@ -507,6 +507,141 @@ func TestGenerateDynamoComponentsDeployments(t *testing.T) {
},
wantErr: true,
},
{
name: "Test GenerateDynamoComponentsDeployments planner",
args: args{
parentDynamoGraphDeployment: &v1alpha1.DynamoGraphDeployment{
ObjectMeta: metav1.ObjectMeta{
Name: "test-dynamographdeployment",
Namespace: "default",
},
Spec: v1alpha1.DynamoGraphDeploymentSpec{
DynamoGraph: "dynamocomponent:ac4e234",
},
},
config: &DynamoGraphConfig{
DynamoTag: "dynamocomponent:MyService1",
Services: []ServiceConfig{
{
Name: "service1",
Config: Config{
Dynamo: &DynamoConfig{
Enabled: true,
Namespace: "default",
Name: "service1",
ComponentType: ComponentTypePlanner,
},
Resources: &Resources{
CPU: "1",
Memory: "1Gi",
GPU: "0",
Custom: map[string]string{},
},
},
},
},
},
ingressSpec: &v1alpha1.IngressSpec{},
},
want: map[string]*v1alpha1.DynamoComponentDeployment{
"service1": {
ObjectMeta: metav1.ObjectMeta{
Name: "test-dynamographdeployment-service1",
Namespace: "default",
Labels: map[string]string{
commonconsts.KubeLabelDynamoComponent: "service1",
commonconsts.KubeLabelDynamoNamespace: "default",
},
},
Spec: v1alpha1.DynamoComponentDeploymentSpec{
DynamoComponent: "dynamocomponent:ac4e234",
DynamoTag: "dynamocomponent:MyService1",
DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
ServiceName: "service1",
DynamoNamespace: &[]string{"default"}[0],
Resources: &compounaiCommon.Resources{
Requests: &compounaiCommon.ResourceItem{
CPU: "1",
Memory: "1Gi",
GPU: "0",
Custom: map[string]string{},
},
Limits: &compounaiCommon.ResourceItem{
CPU: "1",
Memory: "1Gi",
GPU: "0",
Custom: map[string]string{},
},
},
Labels: map[string]string{
commonconsts.KubeLabelDynamoComponent: "service1",
commonconsts.KubeLabelDynamoNamespace: "default",
},
ExtraPodSpec: &compounaiCommon.ExtraPodSpec{
ServiceAccountName: PlannerServiceAccountName,
},
Autoscaling: &v1alpha1.Autoscaling{},
},
},
},
},
wantErr: false,
},
{
name: "Test GenerateDynamoComponentsDeployments dynamo dependency, different namespace",
args: args{
parentDynamoGraphDeployment: &v1alpha1.DynamoGraphDeployment{
ObjectMeta: metav1.ObjectMeta{
Name: "test-dynamographdeployment",
Namespace: "default",
},
Spec: v1alpha1.DynamoGraphDeploymentSpec{
DynamoGraph: "dynamocomponent:ac4e234",
},
},
config: &DynamoGraphConfig{
DynamoTag: "dynamocomponent:MyService2",
EntryService: "service1",
Services: []ServiceConfig{
{
Name: "service1",
Dependencies: []map[string]string{{"service": "service2"}},
Config: Config{
Dynamo: &DynamoConfig{
Enabled: true,
Namespace: "namespace1",
Name: "service1",
},
Resources: &Resources{
CPU: "1",
Memory: "1Gi",
GPU: "0",
Custom: map[string]string{},
},
Autoscaling: &Autoscaling{
MinReplicas: 1,
MaxReplicas: 5,
},
},
},
{
Name: "service2",
Dependencies: []map[string]string{},
Config: Config{
Dynamo: &DynamoConfig{
Enabled: true,
Namespace: "namespace2",
Name: "service2",
},
},
},
},
},
ingressSpec: &v1alpha1.IngressSpec{},
},
want: nil,
wantErr: true,
},
{
name: "Test GenerateDynamoComponentsDeployments ingress enabled by default",
args: args{
......
......@@ -173,6 +173,11 @@ def main(
if service_name and service_name != service.name:
service = service.find_dependent_by_name(service_name)
# Set namespace in dynamo_context if service is a dynamo component
if service.is_dynamo_component():
namespace, _ = service.dynamo_address()
dynamo_context["namespace"] = namespace
configure_dynamo_logging(service_name=service_name, worker_id=worker_id)
if runner_map:
BentoMLContainer.remote_runner_mapping.set(
......
......@@ -19,6 +19,7 @@ import logging
import os
from collections import defaultdict
from dataclasses import asdict, dataclass
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Tuple, TypeVar, Union
# WARNING: internal
......@@ -34,6 +35,16 @@ T = TypeVar("T", bound=object)
logger = logging.getLogger(__name__)
class ComponentType(str, Enum):
    """Kinds of Dynamo components.

    Members are ``str`` instances as well, so each one compares equal to its
    plain string value.
    """

    # Only the planner kind exists today. Further kinds (for example
    # METRICS = "metrics" or MONITOR = "monitor") can be added here later.
    PLANNER = "planner"
class RuntimeLinkedServices:
"""
A class to track the linked services in the runtime.
......@@ -68,6 +79,9 @@ class DynamoConfig:
name: str | None = None
namespace: str | None = None
custom_lease: LeaseConfig | None = None
component_type: ComponentType | None = (
None # Indicates if this is a meta/system component
)
@dataclass
......
......@@ -17,6 +17,7 @@ import logging
import subprocess
from pathlib import Path
from components.planner_service import Planner
from components.processor import Processor
from components.worker import VllmWorker
from fastapi import FastAPI
......@@ -60,6 +61,7 @@ class FrontendConfig(BaseModel):
app=FastAPI(title="LLM Example"),
)
class Frontend:
planner = depends(Planner)
worker = depends(VllmWorker)
processor = depends(Processor)
......
......@@ -30,7 +30,7 @@ from tensorboardX import SummaryWriter
from utils.prefill_queue import PrefillQueue
from dynamo.llm import KvMetricsAggregator
from dynamo.planner import LocalConnector
from dynamo.planner import KubernetesConnector, LocalConnector
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
......@@ -53,7 +53,12 @@ class Planner:
self.runtime = runtime
self.args = args
self.namespace = args.namespace
self.connector = LocalConnector(args.namespace, runtime)
if args.environment == "local":
self.connector = LocalConnector(args.namespace, runtime)
elif args.environment == "kubernetes":
self.connector = KubernetesConnector(args.namespace)
else:
raise ValueError(f"Invalid environment: {args.environment}")
self._prefill_queue_nats_server = os.getenv(
"NATS_SERVER", "nats://localhost:4222"
......@@ -355,7 +360,8 @@ class Planner:
await asyncio.sleep(self.args.metric_pulling_interval / 10)
@dynamo_worker()
# @dynamo_worker()
# TODO: let's make it such that planner still works via CLI invocation
async def start_planner(runtime: DistributedRuntime, args: argparse.Namespace):
planner = Planner(runtime, args)
console = Console()
......@@ -465,5 +471,11 @@ if __name__ == "__main__":
default=1,
help="Number of GPUs per prefill engine",
)
parser.add_argument(
"--environment",
type=str,
default="local",
help="Environment to run the planner in (local, kubernetes)",
)
args = parser.parse_args()
asyncio.run(start_planner(args))
asyncio.run(dynamo_worker()(start_planner)(args))
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from pydantic import BaseModel
from components.planner import start_planner # type: ignore[attr-defined]
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.sdk import async_on_start, dynamo_context, dynamo_endpoint, service
from dynamo.sdk.lib.config import ServiceConfig
from dynamo.sdk.lib.image import DYNAMO_IMAGE
logger = logging.getLogger(__name__)
class RequestType(BaseModel):
    """Request payload accepted by the planner service's endpoint."""

    # The only field of the request: a single string.
    text: str
@service(
    dynamo={
        "enabled": True,
        "namespace": "dynamo",
        "component_type": "planner",
    },
    resources={"cpu": "10", "memory": "20Gi"},
    workers=1,
    image=DYNAMO_IMAGE,
)
class Planner:
    """Service that assembles planner arguments and hands them to ``start_planner``.

    On construction it reads the runtime and namespace from ``dynamo_context``
    and the ``environment`` / ``no-operation`` settings from the "Planner"
    section of the service configuration. Every other planner argument is
    hard-coded in ``__init__``.
    """

    def __init__(self):
        """Configure logging and build the ``argparse.Namespace`` for the planner."""
        configure_dynamo_logging(service_name="Planner")
        logger.info("Starting planner")

        self.runtime = dynamo_context["runtime"]
        service_config = ServiceConfig.get_instance()

        # The active namespace is taken from dynamo_context, not from the
        # service configuration.
        self.namespace = dynamo_context["namespace"]

        # Read the "Planner" section once; absent keys fall back to
        # "local" and True respectively.
        planner_settings = service_config.get("Planner", {})
        self.environment = planner_settings.get("environment", "local")
        self.no_operation = planner_settings.get("no-operation", True)

        # Only namespace, environment and no_operation come from the running
        # service; the remaining values are fixed here.
        # NOTE(review): presumably these mirror the CLI defaults in
        # planner.py — confirm they stay in sync.
        planner_options = {
            "namespace": self.namespace,
            "environment": self.environment,
            "served_model_name": "vllm",
            "no_operation": self.no_operation,
            "log_dir": None,
            "adjustment_interval": 10,
            "metric_pulling_interval": 1,
            "max_gpu_budget": 8,
            "min_endpoint": 1,
            "decode_kv_scale_up_threshold": 0.9,
            "decode_kv_scale_down_threshold": 0.5,
            "prefill_queue_scale_up_threshold": 5,
            "prefill_queue_scale_down_threshold": 0.2,
            "decode_engine_num_gpu": 1,
            "prefill_engine_num_gpu": 1,
        }
        self.args = argparse.Namespace(**planner_options)

    @async_on_start
    async def async_init(self):
        """Wait 60 seconds, then await ``start_planner`` with the stored runtime and args."""
        import asyncio

        # NOTE(review): the reason for the fixed 60-second pause is not shown
        # here — presumably it lets other components come up first; confirm.
        await asyncio.sleep(60)

        logger.info("Calling start_planner")
        await start_planner(self.runtime, self.args)
        # NOTE(review): this is logged only after start_planner returns —
        # confirm that matches the intended meaning of "started".
        logger.info("Planner started")

    @dynamo_endpoint()
    async def generate(self, request: RequestType):
        """Placeholder endpoint so that this component exposes an endpoint.

        Ignores ``request`` and yields a single fixed string.
        """
        yield "mock endpoint"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment