Unverified Commit b2aa2317 authored by julienmancuso's avatar julienmancuso Committed by GitHub
Browse files

feat: deploy planner in operator (#921)


Co-authored-by: default avatarmohammedabdulwahhab <furkhan324@berkeley.edu>
parent 57975b27
...@@ -13,9 +13,15 @@ ...@@ -13,9 +13,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
__all__ = ["CircusController", "LocalConnector", "PlannerConnector"] __all__ = [
"CircusController",
"LocalConnector",
"PlannerConnector",
"KubernetesConnector",
]
# Import the classes # Import the classes
from dynamo.planner.circusd import CircusController from dynamo.planner.circusd import CircusController
from dynamo.planner.kubernetes_connector import KubernetesConnector
from dynamo.planner.local_connector import LocalConnector from dynamo.planner.local_connector import LocalConnector
from dynamo.planner.planner_connector import PlannerConnector from dynamo.planner.planner_connector import PlannerConnector
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import functools
from typing import Optional

from kubernetes import client, config
class KubernetesAPI:
    """Access Dynamo custom resources through the in-cluster Kubernetes API.

    The Kubernetes client calls used here are synchronous (their return values
    are used directly, without ``await``), so the coroutine methods run them in
    the event loop's default executor instead of blocking the loop while a
    request is in flight.
    """

    def __init__(self):
        # Load kubernetes configuration
        config.load_incluster_config()  # for in-cluster deployment
        self.custom_api = client.CustomObjectsApi()
        self.current_namespace = self._get_current_namespace()

    def _get_current_namespace(self) -> str:
        """Get the current namespace if running inside a k8s cluster.

        Returns:
            The contents of the service-account namespace file with
            surrounding whitespace stripped, or ``"default"`` when that file
            does not exist.
        """
        try:
            with open(
                "/var/run/secrets/kubernetes.io/serviceaccount/namespace",
                "r",
                encoding="utf-8",
            ) as f:
                return f.read().strip()
        except FileNotFoundError:
            # Fallback to 'default' if not running in k8s
            return "default"

    @staticmethod
    async def _run_blocking(func, **kwargs):
        """Run the synchronous ``func(**kwargs)`` in the default executor."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, functools.partial(func, **kwargs))

    async def _run_or_none_if_missing(self, func, **kwargs):
        """Like ``_run_blocking``, but map an HTTP 404 ``ApiException`` to None.

        Any other ``ApiException`` is re-raised unchanged.
        """
        try:
            return await self._run_blocking(func, **kwargs)
        except client.ApiException as e:
            if e.status == 404:
                return None
            raise

    async def get_graph_deployment(
        self, component_name: str, dynamo_namespace: str
    ) -> Optional[dict]:
        """
        Get DynamoGraphDeployment by first finding the associated DynamoComponentDeployment
        and then retrieving its owner reference.

        Args:
            component_name: The name of the component
            dynamo_namespace: The dynamo namespace

        Returns:
            The DynamoGraphDeployment object or None if not found

        Raises:
            ValueError: If more than one DynamoComponentDeployment matches the
                component name and dynamo namespace labels.
            kubernetes.client.ApiException: For API errors other than 404.
        """
        # First, find the DynamoComponentDeployment using the component name and namespace labels
        label_selector = f"nvidia.com/dynamo-component={component_name},nvidia.com/dynamo-namespace={dynamo_namespace}"
        component_deployments = await self._run_or_none_if_missing(
            self.custom_api.list_namespaced_custom_object,
            group="nvidia.com",
            version="v1alpha1",
            namespace=self.current_namespace,
            plural="dynamocomponentdeployments",
            label_selector=label_selector,
        )
        if component_deployments is None:
            return None

        items = component_deployments.get("items", [])
        if not items:
            return None
        if len(items) > 1:
            raise ValueError(
                f"Multiple component deployments found for component {component_name} in dynamo namespace {dynamo_namespace}. "
                "Expected exactly one deployment."
            )

        # Get the component deployment and extract the owner reference
        owner_refs = items[0].get("metadata", {}).get("ownerReferences", [])

        # Find the DynamoGraphDeployment in the owner references
        graph_deployment_ref = next(
            (
                ref
                for ref in owner_refs
                if ref.get("apiVersion") == "nvidia.com/v1alpha1"
                and ref.get("kind") == "DynamoGraphDeployment"
            ),
            None,
        )
        if not graph_deployment_ref:
            return None

        # Get the actual DynamoGraphDeployment using the name from the owner reference
        graph_deployment_name = graph_deployment_ref.get("name")
        if not graph_deployment_name:
            return None

        return await self._run_or_none_if_missing(
            self.custom_api.get_namespaced_custom_object,
            group="nvidia.com",
            version="v1alpha1",
            namespace=self.current_namespace,
            plural="dynamographdeployments",
            name=graph_deployment_name,
        )

    async def update_graph_replicas(
        self, graph_deployment_name: str, component_name: str, replicas: int
    ) -> None:
        """Update the replicas count for a component in a DynamoGraphDeployment.

        Args:
            graph_deployment_name: Name of the DynamoGraphDeployment to patch
            component_name: Key under ``spec.services`` whose replicas are set
            replicas: The new replica count
        """
        patch = {"spec": {"services": {component_name: {"replicas": replicas}}}}
        await self._run_blocking(
            self.custom_api.patch_namespaced_custom_object,
            group="nvidia.com",
            version="v1alpha1",
            namespace=self.current_namespace,
            plural="dynamographdeployments",
            name=graph_deployment_name,
            body=patch,
        )
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .kube import KubernetesAPI
from .planner_connector import PlannerConnector
class KubernetesConnector(PlannerConnector):
    """Planner connector that scales components by patching replica counts.

    Each operation looks up the DynamoGraphDeployment that owns the component
    (via ``KubernetesAPI.get_graph_deployment``) and then patches the
    component's ``replicas`` value under ``spec.services``.
    """

    def __init__(self, namespace: str):
        """Create the connector.

        Args:
            namespace: The dynamo namespace used when looking up components.
        """
        self.kube_api = KubernetesAPI()
        self.namespace = namespace

    async def add_component(self, component_name: str):
        """Add a component by increasing its replica count by 1.

        Raises:
            ValueError: If no graph deployment is found for the component.
        """
        deployment = await self._require_graph_deployment(component_name)
        # get current replicas or 1 if not found
        current_replicas = self._get_current_replicas(deployment, component_name)
        await self.kube_api.update_graph_replicas(
            self._get_graph_deployment_name(deployment),
            component_name,
            current_replicas + 1,
        )

    async def remove_component(self, component_name: str):
        """Remove a component by decreasing its replica count by 1.

        Nothing is patched when the current replica count is already 0.

        Raises:
            ValueError: If no graph deployment is found for the component.
        """
        deployment = await self._require_graph_deployment(component_name)
        # get current replicas or 1 if not found
        current_replicas = self._get_current_replicas(deployment, component_name)
        if current_replicas > 0:
            await self.kube_api.update_graph_replicas(
                self._get_graph_deployment_name(deployment),
                component_name,
                current_replicas - 1,
            )

    async def _require_graph_deployment(self, component_name: str) -> dict:
        """Fetch the graph deployment for a component or raise ValueError."""
        deployment = await self.kube_api.get_graph_deployment(
            component_name, self.namespace
        )
        if deployment is None:
            # Same wording for add and remove: the lookup key is a component
            # name, not a graph name.
            raise ValueError(
                f"Graph not found for component {component_name} in dynamo namespace {self.namespace}"
            )
        return deployment

    def _get_current_replicas(self, deployment: dict, component_name: str) -> int:
        """Get the current replicas for a component in a graph deployment.

        Returns 1 when the value is missing or null at any level
        (``spec``, ``services``, the component entry, or ``replicas``).
        """
        services = (deployment.get("spec") or {}).get("services") or {}
        replicas = (services.get(component_name) or {}).get("replicas")
        return 1 if replicas is None else replicas

    def _get_graph_deployment_name(self, deployment: dict) -> str:
        """Get the name of the graph deployment"""
        return deployment["metadata"]["name"]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import AsyncMock, Mock
import pytest
from dynamo.planner.kubernetes_connector import KubernetesConnector
@pytest.fixture
def mock_kube_api():
    """Provide a stand-in KubernetesAPI whose two coroutine methods are AsyncMocks."""
    return Mock(
        get_graph_deployment=AsyncMock(),
        update_graph_replicas=AsyncMock(),
    )
@pytest.fixture
def mock_kube_api_class(mock_kube_api):
    """Provide a callable that returns ``mock_kube_api`` when instantiated."""
    return Mock(return_value=mock_kube_api)
@pytest.fixture
def kubernetes_connector(mock_kube_api_class, monkeypatch):
    """Build a KubernetesConnector for namespace "default" backed by the mock API."""
    # Patch the KubernetesAPI class before instantiating the connector
    monkeypatch.setattr(
        "dynamo.planner.kubernetes_connector.KubernetesAPI", mock_kube_api_class
    )
    # KubernetesConnector.__init__ requires the dynamo namespace; passing it
    # here also sets connector.namespace, which the error messages read.
    return KubernetesConnector("default")
@pytest.mark.asyncio
async def test_add_component_increases_replicas(kubernetes_connector, mock_kube_api):
    """add_component looks up the graph and patches replicas from 1 to 2."""
    # Arrange
    component_name = "test-component"
    mock_deployment = {
        "metadata": {"name": "test-graph"},
        "spec": {"services": {"test-component": {"replicas": 1}}},
    }
    mock_kube_api.get_graph_deployment.return_value = mock_deployment

    # Act
    await kubernetes_connector.add_component(component_name)

    # Assert
    # The connector passes its dynamo namespace as the second lookup argument.
    mock_kube_api.get_graph_deployment.assert_called_once_with(
        component_name, kubernetes_connector.namespace
    )
    mock_kube_api.update_graph_replicas.assert_called_once_with(
        "test-graph", component_name, 2
    )
@pytest.mark.asyncio
async def test_add_component_with_no_replicas_specified(
    kubernetes_connector, mock_kube_api
):
    """A service entry without a ``replicas`` key is patched to 2 replicas."""
    name = "test-component"
    graph = {"metadata": {"name": "test-graph"}, "spec": {"services": {name: {}}}}
    mock_kube_api.get_graph_deployment.return_value = graph

    await kubernetes_connector.add_component(name)

    update_call = mock_kube_api.update_graph_replicas
    update_call.assert_called_once_with("test-graph", name, 2)
@pytest.mark.asyncio
async def test_add_component_deployment_not_found(kubernetes_connector, mock_kube_api):
    """add_component raises ValueError when the graph lookup returns None."""
    missing = "test-component"
    mock_kube_api.get_graph_deployment.return_value = None
    expected_message = f"Graph not found for component {missing}"

    with pytest.raises(ValueError, match=expected_message):
        await kubernetes_connector.add_component(missing)
@pytest.mark.asyncio
async def test_remove_component_decreases_replicas(kubernetes_connector, mock_kube_api):
    """remove_component patches a component at 2 replicas down to 1."""
    name = "test-component"
    services = {name: {"replicas": 2}}
    mock_kube_api.get_graph_deployment.return_value = {
        "metadata": {"name": "test-graph"},
        "spec": {"services": services},
    }

    await kubernetes_connector.remove_component(name)

    update_call = mock_kube_api.update_graph_replicas
    update_call.assert_called_once_with("test-graph", name, 1)
@pytest.mark.asyncio
async def test_remove_component_with_zero_replicas(kubernetes_connector, mock_kube_api):
    """remove_component issues no patch when the component is already at 0 replicas."""
    name = "test-component"
    services = {name: {"replicas": 0}}
    mock_kube_api.get_graph_deployment.return_value = {
        "metadata": {"name": "test-graph"},
        "spec": {"services": services},
    }

    await kubernetes_connector.remove_component(name)

    mock_kube_api.update_graph_replicas.assert_not_called()
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
---
# ServiceAccount "planner-serviceaccount", created in the namespace given by
# the chart value .Values.namespace.
apiVersion: v1
kind: ServiceAccount
metadata:
name: planner-serviceaccount
namespace: {{ .Values.namespace }}
---
# Namespaced Role "planner-role": allows get/list/create/update/patch on the
# nvidia.com resources dynamocomponentdeployments and dynamographdeployments.
apiVersion: rbac.authorization.k8s.io/v1
kind: Role
metadata:
name: planner-role
namespace: {{ .Values.namespace }}
rules:
- apiGroups: ["nvidia.com"]
resources: ["dynamocomponentdeployments", "dynamographdeployments"]
verbs: ["get", "list", "create", "update", "patch"]
---
# RoleBinding "planner-binding": grants the Role "planner-role" to the
# ServiceAccount "planner-serviceaccount" in the .Values.namespace namespace.
apiVersion: rbac.authorization.k8s.io/v1
kind: RoleBinding
metadata:
name: planner-binding
namespace: {{ .Values.namespace }}
subjects:
- kind: ServiceAccount
name: planner-serviceaccount
namespace: {{ .Values.namespace }}
roleRef:
kind: Role
name: planner-role
apiGroup: rbac.authorization.k8s.io
\ No newline at end of file
...@@ -182,7 +182,6 @@ func main() { ...@@ -182,7 +182,6 @@ func main() {
} }
if err = (&controller.DynamoComponentDeploymentReconciler{ if err = (&controller.DynamoComponentDeploymentReconciler{
Client: mgr.GetClient(), Client: mgr.GetClient(),
Scheme: mgr.GetScheme(),
Recorder: mgr.GetEventRecorderFor("dynamocomponentdeployment"), Recorder: mgr.GetEventRecorderFor("dynamocomponentdeployment"),
Config: ctrlConfig, Config: ctrlConfig,
NatsAddr: natsAddr, NatsAddr: natsAddr,
...@@ -204,7 +203,6 @@ func main() { ...@@ -204,7 +203,6 @@ func main() {
} }
if err = (&controller.DynamoGraphDeploymentReconciler{ if err = (&controller.DynamoGraphDeploymentReconciler{
Client: mgr.GetClient(), Client: mgr.GetClient(),
Scheme: mgr.GetScheme(),
Recorder: mgr.GetEventRecorderFor("dynamographdeployment"), Recorder: mgr.GetEventRecorderFor("dynamographdeployment"),
Config: ctrlConfig, Config: ctrlConfig,
VirtualServiceGateway: istioVirtualServiceGateway, VirtualServiceGateway: istioVirtualServiceGateway,
......
...@@ -9,7 +9,6 @@ require ( ...@@ -9,7 +9,6 @@ require (
emperror.dev/errors v0.8.1 emperror.dev/errors v0.8.1
github.com/apparentlymart/go-shquot v0.0.1 github.com/apparentlymart/go-shquot v0.0.1
github.com/bsm/gomega v1.27.10 github.com/bsm/gomega v1.27.10
github.com/cisco-open/k8s-objectmatcher v1.9.0
github.com/huandu/xstrings v1.4.0 github.com/huandu/xstrings v1.4.0
github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/onsi/ginkgo/v2 v2.19.0 github.com/onsi/ginkgo/v2 v2.19.0
...@@ -95,7 +94,6 @@ require ( ...@@ -95,7 +94,6 @@ require (
google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094 // indirect
google.golang.org/grpc v1.65.0 // indirect google.golang.org/grpc v1.65.0 // indirect
google.golang.org/protobuf v1.34.2 // indirect google.golang.org/protobuf v1.34.2 // indirect
gopkg.in/evanphx/json-patch.v4 v4.12.0 // indirect
gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/inf.v0 v0.9.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
k8s.io/klog/v2 v2.130.1 // indirect k8s.io/klog/v2 v2.130.1 // indirect
......
...@@ -10,8 +10,6 @@ github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= ...@@ -10,8 +10,6 @@ github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cisco-open/k8s-objectmatcher v1.9.0 h1:/sfuO0BD09fpynZjXsqeZrh28Juc4VEwc2P6Ov/Q6fM=
github.com/cisco-open/k8s-objectmatcher v1.9.0/go.mod h1:CH4E6qAK+q+JwKFJn0DaTNqxrbmWCaDQzGthKLK4nZ0=
github.com/coreos/go-semver v0.3.1 h1:yi21YpKnrx1gt5R+la8n5WgS0kCrsPp33dmEyHReZr4= github.com/coreos/go-semver v0.3.1 h1:yi21YpKnrx1gt5R+la8n5WgS0kCrsPp33dmEyHReZr4=
github.com/coreos/go-semver v0.3.1/go.mod h1:irMmmIw/7yzSRPWryHsK7EYSg09caPQL03VsM8rvUec= github.com/coreos/go-semver v0.3.1/go.mod h1:irMmmIw/7yzSRPWryHsK7EYSg09caPQL03VsM8rvUec=
github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs= github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs=
......
...@@ -23,7 +23,6 @@ import ( ...@@ -23,7 +23,6 @@ import (
"context" "context"
"fmt" "fmt"
"os" "os"
"reflect"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
...@@ -43,14 +42,12 @@ import ( ...@@ -43,14 +42,12 @@ import (
commonconsts "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/consts" commonconsts "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/consts"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/controller_common" "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/controller_common"
commonController "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/controller_common" commonController "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/controller_common"
"github.com/cisco-open/k8s-objectmatcher/patch"
"github.com/huandu/xstrings" "github.com/huandu/xstrings"
istioNetworking "istio.io/api/networking/v1beta1" istioNetworking "istio.io/api/networking/v1beta1"
networkingv1beta1 "istio.io/client-go/pkg/apis/networking/v1beta1" networkingv1beta1 "istio.io/client-go/pkg/apis/networking/v1beta1"
k8serrors "k8s.io/apimachinery/pkg/api/errors" k8serrors "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/api/meta" "k8s.io/apimachinery/pkg/api/meta"
"k8s.io/apimachinery/pkg/api/resource" "k8s.io/apimachinery/pkg/api/resource"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/intstr" "k8s.io/apimachinery/pkg/util/intstr"
"k8s.io/client-go/tools/record" "k8s.io/client-go/tools/record"
...@@ -58,7 +55,6 @@ import ( ...@@ -58,7 +55,6 @@ import (
ctrl "sigs.k8s.io/controller-runtime" ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/builder" "sigs.k8s.io/controller-runtime/pkg/builder"
"sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/client/apiutil"
"sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil"
"sigs.k8s.io/controller-runtime/pkg/event" "sigs.k8s.io/controller-runtime/pkg/event"
"sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/log"
...@@ -77,12 +73,12 @@ const ( ...@@ -77,12 +73,12 @@ const (
DeploymentTargetTypeDebug = "debug" DeploymentTargetTypeDebug = "debug"
HeaderNameDebug = "X-Nvidia-Debug" HeaderNameDebug = "X-Nvidia-Debug"
DefaultIngressSuffix = "local" DefaultIngressSuffix = "local"
KubernetesDeploymentStrategy = "kubernetes"
) )
// DynamoComponentDeploymentReconciler reconciles a DynamoComponentDeployment object // DynamoComponentDeploymentReconciler reconciles a DynamoComponentDeployment object
type DynamoComponentDeploymentReconciler struct { type DynamoComponentDeploymentReconciler struct {
client.Client client.Client
Scheme *runtime.Scheme
Recorder record.EventRecorder Recorder record.EventRecorder
Config controller_common.Config Config controller_common.Config
NatsAddr string NatsAddr string
...@@ -257,10 +253,12 @@ func (r *DynamoComponentDeploymentReconciler) Reconcile(ctx context.Context, req ...@@ -257,10 +253,12 @@ func (r *DynamoComponentDeploymentReconciler) Reconcile(ctx context.Context, req
} }
// create or update api-server hpa // create or update api-server hpa
modified_, _, err = createOrUpdateResource(ctx, r, generateResourceOption{ modified_, _, err = commonController.SyncResource(ctx, r, dynamoComponentDeployment, func(ctx context.Context) (*autoscalingv2.HorizontalPodAutoscaler, bool, error) {
dynamoComponentDeployment: dynamoComponentDeployment, return r.generateHPA(generateResourceOption{
dynamoComponent: dynamoComponentCR, dynamoComponentDeployment: dynamoComponentDeployment,
}, r.generateHPA) dynamoComponent: dynamoComponentCR,
})
})
if err != nil { if err != nil {
return return
} }
...@@ -408,7 +406,7 @@ func (r *DynamoComponentDeploymentReconciler) reconcilePVC(ctx context.Context, ...@@ -408,7 +406,7 @@ func (r *DynamoComponentDeploymentReconciler) reconcilePVC(ctx context.Context,
return nil, err return nil, err
} }
pvc = constructPVC(crd, pvcConfig) pvc = constructPVC(crd, pvcConfig)
if err := controllerutil.SetControllerReference(crd, pvc, r.Scheme); err != nil { if err := controllerutil.SetControllerReference(crd, pvc, r.Client.Scheme()); err != nil {
logger.Error(err, "Failed to set controller reference", "pvc", pvc.Name) logger.Error(err, "Failed to set controller reference", "pvc", pvc.Name)
return nil, err return nil, err
} }
...@@ -458,23 +456,27 @@ func (r *DynamoComponentDeploymentReconciler) setStatusConditions(ctx context.Co ...@@ -458,23 +456,27 @@ func (r *DynamoComponentDeploymentReconciler) setStatusConditions(ctx context.Co
func (r *DynamoComponentDeploymentReconciler) createOrUpdateOrDeleteDeployments(ctx context.Context, opt generateResourceOption) (modified bool, depl *appsv1.Deployment, err error) { func (r *DynamoComponentDeploymentReconciler) createOrUpdateOrDeleteDeployments(ctx context.Context, opt generateResourceOption) (modified bool, depl *appsv1.Deployment, err error) {
containsStealingTrafficDebugModeEnabled := checkIfContainsStealingTrafficDebugModeEnabled(opt.dynamoComponentDeployment) containsStealingTrafficDebugModeEnabled := checkIfContainsStealingTrafficDebugModeEnabled(opt.dynamoComponentDeployment)
// create the main deployment // create the main deployment
modified, depl, err = createOrUpdateResource(ctx, r, generateResourceOption{ modified, depl, err = commonController.SyncResource(ctx, r, opt.dynamoComponentDeployment, func(ctx context.Context) (*appsv1.Deployment, bool, error) {
dynamoComponentDeployment: opt.dynamoComponentDeployment, return r.generateDeployment(ctx, generateResourceOption{
dynamoComponent: opt.dynamoComponent, dynamoComponentDeployment: opt.dynamoComponentDeployment,
isStealingTrafficDebugModeEnabled: false, dynamoComponent: opt.dynamoComponent,
containsStealingTrafficDebugModeEnabled: containsStealingTrafficDebugModeEnabled, isStealingTrafficDebugModeEnabled: false,
}, r.generateDeployment) containsStealingTrafficDebugModeEnabled: containsStealingTrafficDebugModeEnabled,
})
})
if err != nil { if err != nil {
err = errors.Wrap(err, "create or update deployment") err = errors.Wrap(err, "create or update deployment")
return return
} }
// create the debug deployment // create the debug deployment
modified2, _, err := createOrUpdateResource(ctx, r, generateResourceOption{ modified2, _, err := commonController.SyncResource(ctx, r, opt.dynamoComponentDeployment, func(ctx context.Context) (*appsv1.Deployment, bool, error) {
dynamoComponentDeployment: opt.dynamoComponentDeployment, return r.generateDeployment(ctx, generateResourceOption{
dynamoComponent: opt.dynamoComponent, dynamoComponentDeployment: opt.dynamoComponentDeployment,
isStealingTrafficDebugModeEnabled: true, dynamoComponent: opt.dynamoComponent,
containsStealingTrafficDebugModeEnabled: containsStealingTrafficDebugModeEnabled, isStealingTrafficDebugModeEnabled: true,
}, r.generateDeployment) containsStealingTrafficDebugModeEnabled: containsStealingTrafficDebugModeEnabled,
})
})
if err != nil { if err != nil {
err = errors.Wrap(err, "create or update debug deployment") err = errors.Wrap(err, "create or update debug deployment")
} }
...@@ -482,135 +484,6 @@ func (r *DynamoComponentDeploymentReconciler) createOrUpdateOrDeleteDeployments( ...@@ -482,135 +484,6 @@ func (r *DynamoComponentDeploymentReconciler) createOrUpdateOrDeleteDeployments(
return return
} }
// createOrUpdateResource creates, updates or deletes one resource generated for
// a DynamoComponentDeployment.
//
// generateResource returns the desired object and a flag saying whether the
// object should be deleted rather than kept. The existing object with the same
// name and namespace is then fetched and one of the following happens:
//   - not found, delete requested: nothing is done;
//   - not found otherwise: the last-applied annotation and controller reference
//     are set and the object is created;
//   - found, delete requested: the existing object is deleted;
//   - found otherwise: a patch against the existing object is calculated and the
//     object is updated only when that patch is non-empty.
//
// modified reports whether a create, update or delete was performed. res is the
// created/updated object, the existing object when no update was needed, and
// the zero value after a delete or when nothing was done. Failures are logged,
// recorded as warning events on opt.dynamoComponentDeployment, and returned.
//
//nolint:nakedret
func createOrUpdateResource[T client.Object](ctx context.Context, r *DynamoComponentDeploymentReconciler, opt generateResourceOption, generateResource func(ctx context.Context, opt generateResourceOption) (T, bool, error)) (modified bool, res T, err error) {
	logs := log.FromContext(ctx)

	resource, toDelete, err := generateResource(ctx, opt)
	if err != nil {
		return
	}

	resourceNamespace := resource.GetNamespace()
	resourceName := resource.GetName()
	resourceType := reflect.TypeOf(resource).Elem().Name()
	logs = logs.WithValues("namespace", resourceNamespace, "resourceName", resourceName, "resourceType", resourceType)

	// Retrieve the GroupVersionKind (GVK) of the desired object
	gvk, err := apiutil.GVKForObject(resource, r.Client.Scheme())
	if err != nil {
		logs.Error(err, "Failed to get GVK for object")
		return
	}

	// Create a new instance of the object
	obj, err := r.Client.Scheme().New(gvk)
	if err != nil {
		logs.Error(err, "Failed to create a new object for GVK")
		return
	}

	// Type assertion to ensure the object implements client.Object
	oldResource, ok := obj.(T)
	if !ok {
		// Report the mismatch instead of returning silently with a nil error,
		// which the caller could not tell apart from a successful no-op.
		err = fmt.Errorf("object of type %T created for %s does not match the expected resource type %s", obj, gvk.String(), resourceType)
		logs.Error(err, "Failed to convert the new object to the resource type")
		return
	}

	err = r.Get(ctx, types.NamespacedName{Name: resourceName, Namespace: resourceNamespace}, oldResource)
	oldResourceIsNotFound := k8serrors.IsNotFound(err)
	if err != nil && !oldResourceIsNotFound {
		r.Recorder.Eventf(opt.dynamoComponentDeployment, corev1.EventTypeWarning, fmt.Sprintf("Get%s", resourceType), "Failed to get %s %s: %s", resourceType, resourceNamespace, err)
		logs.Error(err, fmt.Sprintf("Failed to get %s.", resourceType))
		return
	}
	// A NotFound error is an expected outcome handled below, not a failure.
	err = nil

	if oldResourceIsNotFound {
		if toDelete {
			logs.Info("Resource not found. Nothing to do.")
			return
		}
		logs.Info("Resource not found. Creating a new one.")

		err = errors.Wrapf(patch.DefaultAnnotator.SetLastAppliedAnnotation(resource), "set last applied annotation for resource %s", resourceName)
		if err != nil {
			logs.Error(err, "Failed to set last applied annotation.")
			r.Recorder.Eventf(opt.dynamoComponentDeployment, corev1.EventTypeWarning, "SetLastAppliedAnnotation", "Failed to set last applied annotation for %s %s: %s", resourceType, resourceNamespace, err)
			return
		}

		err = ctrl.SetControllerReference(opt.dynamoComponentDeployment, resource, r.Scheme)
		if err != nil {
			logs.Error(err, "Failed to set controller reference.")
			r.Recorder.Eventf(opt.dynamoComponentDeployment, corev1.EventTypeWarning, "SetControllerReference", "Failed to set controller reference for %s %s: %s", resourceType, resourceNamespace, err)
			return
		}

		r.Recorder.Eventf(opt.dynamoComponentDeployment, corev1.EventTypeNormal, fmt.Sprintf("Create%s", resourceType), "Creating a new %s %s", resourceType, resourceNamespace)
		err = r.Create(ctx, resource)
		if err != nil {
			logs.Error(err, "Failed to create Resource.")
			r.Recorder.Eventf(opt.dynamoComponentDeployment, corev1.EventTypeWarning, fmt.Sprintf("Create%s", resourceType), "Failed to create %s %s: %s", resourceType, resourceNamespace, err)
			return
		}
		logs.Info(fmt.Sprintf("%s created.", resourceType))
		r.Recorder.Eventf(opt.dynamoComponentDeployment, corev1.EventTypeNormal, fmt.Sprintf("Create%s", resourceType), "Created %s %s", resourceType, resourceNamespace)
		modified = true
		res = resource
	} else {
		logs.Info(fmt.Sprintf("%s found.", resourceType))
		if toDelete {
			logs.Info(fmt.Sprintf("%s is no longer desired. Deleting the existing one.", resourceType))
			err = r.Delete(ctx, oldResource)
			if err != nil {
				logs.Error(err, fmt.Sprintf("Failed to delete %s.", resourceType))
				r.Recorder.Eventf(opt.dynamoComponentDeployment, corev1.EventTypeWarning, fmt.Sprintf("Delete%s", resourceType), "Failed to delete %s %s: %s", resourceType, resourceNamespace, err)
				return
			}
			logs.Info(fmt.Sprintf("%s deleted.", resourceType))
			r.Recorder.Eventf(opt.dynamoComponentDeployment, corev1.EventTypeNormal, fmt.Sprintf("Delete%s", resourceType), "Deleted %s %s", resourceType, resourceNamespace)
			modified = true
			return
		}

		var patchResult *patch.PatchResult
		patchResult, err = patch.DefaultPatchMaker.Calculate(oldResource, resource)
		if err != nil {
			logs.Error(err, "Failed to calculate patch.")
			r.Recorder.Eventf(opt.dynamoComponentDeployment, corev1.EventTypeWarning, fmt.Sprintf("CalculatePatch%s", resourceType), "Failed to calculate patch for %s %s: %s", resourceType, resourceNamespace, err)
			return
		}

		if !patchResult.IsEmpty() {
			logs.Info(fmt.Sprintf("%s spec is different. Updating %s. The patch result is: %s", resourceType, resourceType, patchResult.String()))

			err = errors.Wrapf(patch.DefaultAnnotator.SetLastAppliedAnnotation(resource), "set last applied annotation for resource %s", resourceName)
			if err != nil {
				logs.Error(err, "Failed to set last applied annotation.")
				r.Recorder.Eventf(opt.dynamoComponentDeployment, corev1.EventTypeWarning, fmt.Sprintf("SetLastAppliedAnnotation%s", resourceType), "Failed to set last applied annotation for %s %s: %s", resourceType, resourceNamespace, err)
				return
			}

			r.Recorder.Eventf(opt.dynamoComponentDeployment, corev1.EventTypeNormal, fmt.Sprintf("Update%s", resourceType), "Updating %s %s", resourceType, resourceNamespace)
			// Carry over the existing resource version so the update targets
			// the object that was just read.
			resource.SetResourceVersion(oldResource.GetResourceVersion())
			err = r.Update(ctx, resource)
			if err != nil {
				logs.Error(err, fmt.Sprintf("Failed to update %s.", resourceType))
				r.Recorder.Eventf(opt.dynamoComponentDeployment, corev1.EventTypeWarning, fmt.Sprintf("Update%s", resourceType), "Failed to update %s %s: %s", resourceType, resourceNamespace, err)
				return
			}
			logs.Info(fmt.Sprintf("%s updated.", resourceType))
			r.Recorder.Eventf(opt.dynamoComponentDeployment, corev1.EventTypeNormal, fmt.Sprintf("Update%s", resourceType), "Updated %s %s", resourceType, resourceNamespace)
			modified = true
			res = resource
		} else {
			logs.Info(fmt.Sprintf("%s spec is the same. Skipping update.", resourceType))
			r.Recorder.Eventf(opt.dynamoComponentDeployment, corev1.EventTypeNormal, fmt.Sprintf("Update%s", resourceType), "Skipping update %s %s", resourceType, resourceNamespace)
			res = oldResource
		}
	}
	return
}
func getResourceAnnotations(dynamoComponentDeployment *v1alpha1.DynamoComponentDeployment) map[string]string { func getResourceAnnotations(dynamoComponentDeployment *v1alpha1.DynamoComponentDeployment) map[string]string {
resourceAnnotations := dynamoComponentDeployment.Spec.Annotations resourceAnnotations := dynamoComponentDeployment.Spec.Annotations
if resourceAnnotations == nil { if resourceAnnotations == nil {
...@@ -654,40 +527,46 @@ func (r *DynamoComponentDeploymentReconciler) createOrUpdateOrDeleteServices(ctx ...@@ -654,40 +527,46 @@ func (r *DynamoComponentDeploymentReconciler) createOrUpdateOrDeleteServices(ctx
isDebugPodReceiveProductionTrafficEnabled := checkIfIsDebugPodReceiveProductionTrafficEnabled(resourceAnnotations) isDebugPodReceiveProductionTrafficEnabled := checkIfIsDebugPodReceiveProductionTrafficEnabled(resourceAnnotations)
containsStealingTrafficDebugModeEnabled := checkIfContainsStealingTrafficDebugModeEnabled(opt.dynamoComponentDeployment) containsStealingTrafficDebugModeEnabled := checkIfContainsStealingTrafficDebugModeEnabled(opt.dynamoComponentDeployment)
// main generic service // main generic service
modified, _, err = createOrUpdateResource(ctx, r, generateResourceOption{ modified, _, err = commonController.SyncResource(ctx, r, opt.dynamoComponentDeployment, func(ctx context.Context) (*corev1.Service, bool, error) {
dynamoComponentDeployment: opt.dynamoComponentDeployment, return r.generateService(ctx, generateResourceOption{
dynamoComponent: opt.dynamoComponent, dynamoComponentDeployment: opt.dynamoComponentDeployment,
isStealingTrafficDebugModeEnabled: false, dynamoComponent: opt.dynamoComponent,
isDebugPodReceiveProductionTraffic: isDebugPodReceiveProductionTrafficEnabled, isStealingTrafficDebugModeEnabled: false,
containsStealingTrafficDebugModeEnabled: containsStealingTrafficDebugModeEnabled, isDebugPodReceiveProductionTraffic: isDebugPodReceiveProductionTrafficEnabled,
isGenericService: true, containsStealingTrafficDebugModeEnabled: containsStealingTrafficDebugModeEnabled,
}, r.generateService) isGenericService: true,
})
})
if err != nil { if err != nil {
return return
} }
// debug production service (if enabled) // debug production service (if enabled)
modified_, _, err := createOrUpdateResource(ctx, r, generateResourceOption{ modified_, _, err := commonController.SyncResource(ctx, r, opt.dynamoComponentDeployment, func(ctx context.Context) (*corev1.Service, bool, error) {
dynamoComponentDeployment: opt.dynamoComponentDeployment, return r.generateService(ctx, generateResourceOption{
dynamoComponent: opt.dynamoComponent, dynamoComponentDeployment: opt.dynamoComponentDeployment,
isStealingTrafficDebugModeEnabled: false, dynamoComponent: opt.dynamoComponent,
isDebugPodReceiveProductionTraffic: isDebugPodReceiveProductionTrafficEnabled, isStealingTrafficDebugModeEnabled: false,
containsStealingTrafficDebugModeEnabled: containsStealingTrafficDebugModeEnabled, isDebugPodReceiveProductionTraffic: isDebugPodReceiveProductionTrafficEnabled,
isGenericService: false, containsStealingTrafficDebugModeEnabled: containsStealingTrafficDebugModeEnabled,
}, r.generateService) isGenericService: false,
})
})
if err != nil { if err != nil {
return return
} }
modified = modified || modified_ modified = modified || modified_
// debug service (if enabled) // debug service (if enabled)
modified_, _, err = createOrUpdateResource(ctx, r, generateResourceOption{ modified_, _, err = commonController.SyncResource(ctx, r, opt.dynamoComponentDeployment, func(ctx context.Context) (*corev1.Service, bool, error) {
dynamoComponentDeployment: opt.dynamoComponentDeployment, return r.generateService(ctx, generateResourceOption{
dynamoComponent: opt.dynamoComponent, dynamoComponentDeployment: opt.dynamoComponentDeployment,
isStealingTrafficDebugModeEnabled: true, dynamoComponent: opt.dynamoComponent,
isDebugPodReceiveProductionTraffic: isDebugPodReceiveProductionTrafficEnabled, isStealingTrafficDebugModeEnabled: true,
containsStealingTrafficDebugModeEnabled: containsStealingTrafficDebugModeEnabled, isDebugPodReceiveProductionTraffic: isDebugPodReceiveProductionTrafficEnabled,
isGenericService: false, containsStealingTrafficDebugModeEnabled: containsStealingTrafficDebugModeEnabled,
}, r.generateService) isGenericService: false,
})
})
if err != nil { if err != nil {
return return
} }
...@@ -696,11 +575,15 @@ func (r *DynamoComponentDeploymentReconciler) createOrUpdateOrDeleteServices(ctx ...@@ -696,11 +575,15 @@ func (r *DynamoComponentDeploymentReconciler) createOrUpdateOrDeleteServices(ctx
} }
func (r *DynamoComponentDeploymentReconciler) createOrUpdateOrDeleteIngress(ctx context.Context, opt generateResourceOption) (modified bool, err error) { func (r *DynamoComponentDeploymentReconciler) createOrUpdateOrDeleteIngress(ctx context.Context, opt generateResourceOption) (modified bool, err error) {
modified, _, err = createOrUpdateResource(ctx, r, opt, r.generateIngress) modified, _, err = commonController.SyncResource(ctx, r, opt.dynamoComponentDeployment, func(ctx context.Context) (*networkingv1.Ingress, bool, error) {
return r.generateIngress(ctx, opt)
})
if err != nil { if err != nil {
return return
} }
modified_, _, err := createOrUpdateResource(ctx, r, opt, r.generateVirtualService) modified_, _, err := commonController.SyncResource(ctx, r, opt.dynamoComponentDeployment, func(ctx context.Context) (*networkingv1beta1.VirtualService, bool, error) {
return r.generateVirtualService(ctx, opt)
})
if err != nil { if err != nil {
return return
} }
...@@ -964,7 +847,7 @@ type generateResourceOption struct { ...@@ -964,7 +847,7 @@ type generateResourceOption struct {
isGenericService bool isGenericService bool
} }
func (r *DynamoComponentDeploymentReconciler) generateHPA(ctx context.Context, opt generateResourceOption) (*autoscalingv2.HorizontalPodAutoscaler, bool, error) { func (r *DynamoComponentDeploymentReconciler) generateHPA(opt generateResourceOption) (*autoscalingv2.HorizontalPodAutoscaler, bool, error) {
labels := r.getKubeLabels(opt.dynamoComponentDeployment, opt.dynamoComponent) labels := r.getKubeLabels(opt.dynamoComponentDeployment, opt.dynamoComponent)
annotations := r.getKubeAnnotations(opt.dynamoComponentDeployment, opt.dynamoComponent) annotations := r.getKubeAnnotations(opt.dynamoComponentDeployment, opt.dynamoComponent)
...@@ -1151,6 +1034,7 @@ func (r *DynamoComponentDeploymentReconciler) generatePodTemplateSpec(ctx contex ...@@ -1151,6 +1034,7 @@ func (r *DynamoComponentDeploymentReconciler) generatePodTemplateSpec(ctx contex
if opt.dynamoComponentDeployment.Spec.DynamoNamespace != nil && *opt.dynamoComponentDeployment.Spec.DynamoNamespace != "" { if opt.dynamoComponentDeployment.Spec.DynamoNamespace != nil && *opt.dynamoComponentDeployment.Spec.DynamoNamespace != "" {
args = append(args, fmt.Sprintf("--%s.ServiceArgs.dynamo.namespace=%s", opt.dynamoComponentDeployment.Spec.ServiceName, *opt.dynamoComponentDeployment.Spec.DynamoNamespace)) args = append(args, fmt.Sprintf("--%s.ServiceArgs.dynamo.namespace=%s", opt.dynamoComponentDeployment.Spec.ServiceName, *opt.dynamoComponentDeployment.Spec.DynamoNamespace))
} }
args = append(args, fmt.Sprintf("--%s.environment=%s", opt.dynamoComponentDeployment.Spec.ServiceName, KubernetesDeploymentStrategy))
} }
if len(opt.dynamoComponentDeployment.Spec.Envs) > 0 { if len(opt.dynamoComponentDeployment.Spec.Envs) > 0 {
...@@ -1505,7 +1389,7 @@ func getResourcesConfig(resources *dynamoCommon.Resources) (corev1.ResourceRequi ...@@ -1505,7 +1389,7 @@ func getResourcesConfig(resources *dynamoCommon.Resources) (corev1.ResourceRequi
} }
//nolint:nakedret //nolint:nakedret
func (r *DynamoComponentDeploymentReconciler) generateService(ctx context.Context, opt generateResourceOption) (kubeService *corev1.Service, toDelete bool, err error) { func (r *DynamoComponentDeploymentReconciler) generateService(_ context.Context, opt generateResourceOption) (kubeService *corev1.Service, toDelete bool, err error) {
var kubeName string var kubeName string
if opt.isGenericService { if opt.isGenericService {
kubeName = r.getGenericServiceName(opt.dynamoComponentDeployment, opt.dynamoComponent) kubeName = r.getGenericServiceName(opt.dynamoComponentDeployment, opt.dynamoComponent)
...@@ -1602,3 +1486,7 @@ func (r *DynamoComponentDeploymentReconciler) SetupWithManager(mgr ctrl.Manager) ...@@ -1602,3 +1486,7 @@ func (r *DynamoComponentDeploymentReconciler) SetupWithManager(mgr ctrl.Manager)
m.Owns(&autoscalingv2.HorizontalPodAutoscaler{}) m.Owns(&autoscalingv2.HorizontalPodAutoscaler{})
return m.Complete(r) return m.Complete(r)
} }
// GetRecorder returns the event recorder stored in the reconciler's Recorder field.
func (r *DynamoComponentDeploymentReconciler) GetRecorder() record.EventRecorder {
return r.Recorder
}
...@@ -19,12 +19,12 @@ package controller ...@@ -19,12 +19,12 @@ package controller
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"dario.cat/mergo" "dario.cat/mergo"
corev1 "k8s.io/api/core/v1" corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/client-go/tools/record" "k8s.io/client-go/tools/record"
ctrl "sigs.k8s.io/controller-runtime" ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/builder" "sigs.k8s.io/controller-runtime/pkg/builder"
...@@ -34,6 +34,7 @@ import ( ...@@ -34,6 +34,7 @@ import (
"sigs.k8s.io/controller-runtime/pkg/predicate" "sigs.k8s.io/controller-runtime/pkg/predicate"
nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/v1alpha1" nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/v1alpha1"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/consts"
commonController "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/controller_common" commonController "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/controller_common"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/dynamo" "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/dynamo"
) )
...@@ -42,6 +43,8 @@ const ( ...@@ -42,6 +43,8 @@ const (
FailedState = "failed" FailedState = "failed"
ReadyState = "successful" ReadyState = "successful"
PendingState = "pending" PendingState = "pending"
DYN_DEPLOYMENT_CONFIG_ENV_VAR = "DYN_DEPLOYMENT_CONFIG"
) )
type etcdStorage interface { type etcdStorage interface {
...@@ -51,7 +54,6 @@ type etcdStorage interface { ...@@ -51,7 +54,6 @@ type etcdStorage interface {
// DynamoGraphDeploymentReconciler reconciles a DynamoGraphDeployment object // DynamoGraphDeploymentReconciler reconciles a DynamoGraphDeployment object
type DynamoGraphDeploymentReconciler struct { type DynamoGraphDeploymentReconciler struct {
client.Client client.Client
Scheme *runtime.Scheme
Config commonController.Config Config commonController.Config
Recorder record.EventRecorder Recorder record.EventRecorder
VirtualServiceGateway string VirtualServiceGateway string
...@@ -94,6 +96,7 @@ func (r *DynamoGraphDeploymentReconciler) Reconcile(ctx context.Context, req ctr ...@@ -94,6 +96,7 @@ func (r *DynamoGraphDeploymentReconciler) Reconcile(ctx context.Context, req ctr
if err != nil { if err != nil {
dynamoDeployment.SetState(FailedState) dynamoDeployment.SetState(FailedState)
message = err.Error() message = err.Error()
logger.Error(err, "Reconciliation failed")
} }
// update the CRD status condition // update the CRD status condition
dynamoDeployment.AddStatusCondition(metav1.Condition{ dynamoDeployment.AddStatusCondition(metav1.Condition{
...@@ -112,6 +115,7 @@ func (r *DynamoGraphDeploymentReconciler) Reconcile(ctx context.Context, req ctr ...@@ -112,6 +115,7 @@ func (r *DynamoGraphDeploymentReconciler) Reconcile(ctx context.Context, req ctr
deleted, err := commonController.HandleFinalizer(ctx, dynamoDeployment, r.Client, r) deleted, err := commonController.HandleFinalizer(ctx, dynamoDeployment, r.Client, r)
if err != nil { if err != nil {
logger.Error(err, "failed to handle the finalizer")
reason = "failed_to_handle_the_finalizer" reason = "failed_to_handle_the_finalizer"
return ctrl.Result{}, err return ctrl.Result{}, err
} }
...@@ -122,6 +126,7 @@ func (r *DynamoGraphDeploymentReconciler) Reconcile(ctx context.Context, req ctr ...@@ -122,6 +126,7 @@ func (r *DynamoGraphDeploymentReconciler) Reconcile(ctx context.Context, req ctr
// fetch the dynamoGraphConfig // fetch the dynamoGraphConfig
dynamoGraphConfig, err := dynamo.GetDynamoGraphConfig(ctx, dynamoDeployment, r.Recorder) dynamoGraphConfig, err := dynamo.GetDynamoGraphConfig(ctx, dynamoDeployment, r.Recorder)
if err != nil { if err != nil {
logger.Error(err, "failed to get the DynamoGraphConfig")
reason = "failed_to_get_the_DynamoGraphConfig" reason = "failed_to_get_the_DynamoGraphConfig"
return ctrl.Result{}, err return ctrl.Result{}, err
} }
...@@ -129,6 +134,7 @@ func (r *DynamoGraphDeploymentReconciler) Reconcile(ctx context.Context, req ctr ...@@ -129,6 +134,7 @@ func (r *DynamoGraphDeploymentReconciler) Reconcile(ctx context.Context, req ctr
// generate the dynamoComponentsDeployments from the config // generate the dynamoComponentsDeployments from the config
dynamoComponentsDeployments, err := dynamo.GenerateDynamoComponentsDeployments(ctx, dynamoDeployment, dynamoGraphConfig, r.generateDefaultIngressSpec(dynamoDeployment)) dynamoComponentsDeployments, err := dynamo.GenerateDynamoComponentsDeployments(ctx, dynamoDeployment, dynamoGraphConfig, r.generateDefaultIngressSpec(dynamoDeployment))
if err != nil { if err != nil {
logger.Error(err, "failed to generate the DynamoComponentsDeployments")
reason = "failed_to_generate_the_DynamoComponentsDeployments" reason = "failed_to_generate_the_DynamoComponentsDeployments"
return ctrl.Result{}, err return ctrl.Result{}, err
} }
...@@ -138,6 +144,7 @@ func (r *DynamoGraphDeploymentReconciler) Reconcile(ctx context.Context, req ctr ...@@ -138,6 +144,7 @@ func (r *DynamoGraphDeploymentReconciler) Reconcile(ctx context.Context, req ctr
if _, ok := dynamoDeployment.Spec.Services[serviceName]; ok { if _, ok := dynamoDeployment.Spec.Services[serviceName]; ok {
err := mergo.Merge(&deployment.Spec.DynamoComponentDeploymentSharedSpec, dynamoDeployment.Spec.Services[serviceName].DynamoComponentDeploymentSharedSpec, mergo.WithOverride) err := mergo.Merge(&deployment.Spec.DynamoComponentDeploymentSharedSpec, dynamoDeployment.Spec.Services[serviceName].DynamoComponentDeploymentSharedSpec, mergo.WithOverride)
if err != nil { if err != nil {
logger.Error(err, "failed to merge the DynamoComponentsDeployments")
reason = "failed_to_merge_the_DynamoComponentsDeployments" reason = "failed_to_merge_the_DynamoComponentsDeployments"
return ctrl.Result{}, err return ctrl.Result{}, err
} }
...@@ -152,6 +159,11 @@ func (r *DynamoGraphDeploymentReconciler) Reconcile(ctx context.Context, req ctr ...@@ -152,6 +159,11 @@ func (r *DynamoGraphDeploymentReconciler) Reconcile(ctx context.Context, req ctr
if len(dynamoDeployment.Spec.Envs) > 0 { if len(dynamoDeployment.Spec.Envs) > 0 {
deployment.Spec.Envs = mergeEnvs(dynamoDeployment.Spec.Envs, deployment.Spec.Envs) deployment.Spec.Envs = mergeEnvs(dynamoDeployment.Spec.Envs, deployment.Spec.Envs)
} }
err := updateDynDeploymentConfig(deployment, consts.DynamoServicePort)
if err != nil {
logger.Error(err, fmt.Sprintf("Failed to update the %v env var", DYN_DEPLOYMENT_CONFIG_ENV_VAR))
return ctrl.Result{}, err
}
} }
// reconcile the dynamoComponent // reconcile the dynamoComponent
...@@ -165,12 +177,11 @@ func (r *DynamoGraphDeploymentReconciler) Reconcile(ctx context.Context, req ctr ...@@ -165,12 +177,11 @@ func (r *DynamoGraphDeploymentReconciler) Reconcile(ctx context.Context, req ctr
DynamoComponent: dynamoDeployment.Spec.DynamoGraph, DynamoComponent: dynamoDeployment.Spec.DynamoGraph,
}, },
} }
if err := ctrl.SetControllerReference(dynamoDeployment, dynamoComponent, r.Scheme); err != nil { _, dynamoComponent, err = commonController.SyncResource(ctx, r, dynamoDeployment, func(ctx context.Context) (*nvidiacomv1alpha1.DynamoComponent, bool, error) {
reason = "failed_to_set_the_controller_reference_for_the_DynamoComponent" return dynamoComponent, false, nil
return ctrl.Result{}, err })
}
dynamoComponent, err = commonController.SyncResource(ctx, r.Client, dynamoComponent, false)
if err != nil { if err != nil {
logger.Error(err, "failed to sync the DynamoComponent")
reason = "failed_to_sync_the_DynamoComponent" reason = "failed_to_sync_the_DynamoComponent"
return ctrl.Result{}, err return ctrl.Result{}, err
} }
...@@ -186,12 +197,11 @@ func (r *DynamoGraphDeploymentReconciler) Reconcile(ctx context.Context, req ctr ...@@ -186,12 +197,11 @@ func (r *DynamoGraphDeploymentReconciler) Reconcile(ctx context.Context, req ctr
// reconcile the dynamoComponentsDeployments // reconcile the dynamoComponentsDeployments
for serviceName, dynamoComponentDeployment := range dynamoComponentsDeployments { for serviceName, dynamoComponentDeployment := range dynamoComponentsDeployments {
logger.Info("Reconciling the DynamoComponentDeployment", "serviceName", serviceName, "dynamoComponentDeployment", dynamoComponentDeployment) logger.Info("Reconciling the DynamoComponentDeployment", "serviceName", serviceName, "dynamoComponentDeployment", dynamoComponentDeployment)
if err := ctrl.SetControllerReference(dynamoDeployment, dynamoComponentDeployment, r.Scheme); err != nil { _, dynamoComponentDeployment, err = commonController.SyncResource(ctx, r, dynamoDeployment, func(ctx context.Context) (*nvidiacomv1alpha1.DynamoComponentDeployment, bool, error) {
reason = "failed_to_set_the_controller_reference_for_the_DynamoComponentDeployment" return dynamoComponentDeployment, false, nil
return ctrl.Result{}, err })
}
dynamoComponentDeployment, err = commonController.SyncResource(ctx, r.Client, dynamoComponentDeployment, false)
if err != nil { if err != nil {
logger.Error(err, "failed to sync the DynamoComponentDeployment")
reason = "failed_to_sync_the_DynamoComponentDeployment" reason = "failed_to_sync_the_DynamoComponentDeployment"
return ctrl.Result{}, err return ctrl.Result{}, err
} }
...@@ -265,6 +275,39 @@ func mergeEnvs(common, specific []corev1.EnvVar) []corev1.EnvVar { ...@@ -265,6 +275,39 @@ func mergeEnvs(common, specific []corev1.EnvVar) []corev1.EnvVar {
return merged return merged
} }
// updateDynDeploymentConfig rewrites the DYN_DEPLOYMENT_CONFIG env var of the
// given component so that the "port" entry of its own service points at newPort.
//
// The function is a no-op unless the component reports itself as the main
// component. Only the first env var named DYN_DEPLOYMENT_CONFIG is considered.
// Its value is parsed as a JSON object; if the entry keyed by the component's
// ServiceName is itself an object that already contains a "port" key, that key
// is overwritten with newPort. The parsed document is then serialized back into
// the env var, whether or not a port was actually replaced.
//
// An error is returned when the env var value cannot be unmarshalled or when
// the updated document cannot be marshalled; in both cases the env var is left
// untouched.
func updateDynDeploymentConfig(dynamoDeploymentComponent *nvidiacomv1alpha1.DynamoComponentDeployment, newPort int) error {
	if !dynamoDeploymentComponent.IsMainComponent() {
		return nil
	}

	// envs shares its backing array with Spec.Envs, so writes through it are
	// visible on the component itself.
	envs := dynamoDeploymentComponent.Spec.Envs
	for idx := range envs {
		if envs[idx].Name != DYN_DEPLOYMENT_CONFIG_ENV_VAR {
			continue
		}

		var parsed map[string]any
		if err := json.Unmarshal([]byte(envs[idx].Value), &parsed); err != nil {
			return fmt.Errorf("failed to unmarshal %v: %w", DYN_DEPLOYMENT_CONFIG_ENV_VAR, err)
		}

		// Only touch a port that is already declared for this service; never add one.
		serviceSection, isObject := parsed[dynamoDeploymentComponent.Spec.ServiceName].(map[string]any)
		if isObject {
			if _, hasPort := serviceSection["port"]; hasPort {
				serviceSection["port"] = newPort
			}
		}

		encoded, err := json.Marshal(parsed)
		if err != nil {
			return fmt.Errorf("failed to marshal updated config: %w", err)
		}
		envs[idx].Value = string(encoded)

		// Stop after the first matching env var.
		return nil
	}
	return nil
}
func (r *DynamoGraphDeploymentReconciler) FinalizeResource(ctx context.Context, dynamoDeployment *nvidiacomv1alpha1.DynamoGraphDeployment) error { func (r *DynamoGraphDeploymentReconciler) FinalizeResource(ctx context.Context, dynamoDeployment *nvidiacomv1alpha1.DynamoGraphDeployment) error {
// for now doing nothing // for now doing nothing
return nil return nil
...@@ -294,3 +337,7 @@ func (r *DynamoGraphDeploymentReconciler) SetupWithManager(mgr ctrl.Manager) err ...@@ -294,3 +337,7 @@ func (r *DynamoGraphDeploymentReconciler) SetupWithManager(mgr ctrl.Manager) err
WithEventFilter(commonController.EphemeralDeploymentEventFilter(r.Config)). WithEventFilter(commonController.EphemeralDeploymentEventFilter(r.Config)).
Complete(r) Complete(r)
} }
// GetRecorder returns the event recorder stored in the reconciler's Recorder field.
func (r *DynamoGraphDeploymentReconciler) GetRecorder() record.EventRecorder {
return r.Recorder
}
...@@ -22,6 +22,8 @@ import ( ...@@ -22,6 +22,8 @@ import (
"sort" "sort"
"testing" "testing"
nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/v1alpha1"
"github.com/bsm/gomega"
corev1 "k8s.io/api/core/v1" corev1 "k8s.io/api/core/v1"
) )
...@@ -80,3 +82,125 @@ func Test_mergeEnvs(t *testing.T) { ...@@ -80,3 +82,125 @@ func Test_mergeEnvs(t *testing.T) {
}) })
} }
} }
func Test_updateDynDeploymentConfig(t *testing.T) {
type args struct {
dynamoDeploymentComponent *nvidiacomv1alpha1.DynamoComponentDeployment
newPort int
}
tests := []struct {
name string
args args
want []corev1.EnvVar
wantErr bool
}{
{
name: "main component",
args: args{
dynamoDeploymentComponent: &nvidiacomv1alpha1.DynamoComponentDeployment{
Spec: nvidiacomv1alpha1.DynamoComponentDeploymentSpec{
DynamoTag: "graphs.agg:Frontend",
DynamoComponentDeploymentSharedSpec: nvidiacomv1alpha1.DynamoComponentDeploymentSharedSpec{
ServiceName: "Frontend",
Envs: []corev1.EnvVar{
{
Name: "DYN_DEPLOYMENT_CONFIG",
Value: `{"Frontend":{"port":8080},"Planner":{"environment":"kubernetes"}}`,
},
{
Name: "OTHER",
Value: `value`,
},
},
},
},
},
newPort: 3000,
},
want: []corev1.EnvVar{
{
Name: "DYN_DEPLOYMENT_CONFIG",
Value: `{"Frontend":{"port":3000},"Planner":{"environment":"kubernetes"}}`,
},
{
Name: "OTHER",
Value: `value`,
},
},
wantErr: false,
},
{
name: "not main component",
args: args{
dynamoDeploymentComponent: &nvidiacomv1alpha1.DynamoComponentDeployment{
Spec: nvidiacomv1alpha1.DynamoComponentDeploymentSpec{
DynamoTag: "graphs.agg:Frontend",
DynamoComponentDeploymentSharedSpec: nvidiacomv1alpha1.DynamoComponentDeploymentSharedSpec{
ServiceName: "Other",
Envs: []corev1.EnvVar{
{
Name: "DYN_DEPLOYMENT_CONFIG",
Value: `{"Frontend":{"port":8080},"Planner":{"environment":"kubernetes"}}`,
},
{
Name: "OTHER",
Value: `value`,
},
},
},
},
},
newPort: 3000,
},
want: []corev1.EnvVar{
{
Name: "DYN_DEPLOYMENT_CONFIG",
Value: `{"Frontend":{"port":8080},"Planner":{"environment":"kubernetes"}}`,
},
{
Name: "OTHER",
Value: `value`,
},
},
wantErr: false,
},
{
name: "no DYN_DEPLOYMENT_CONFIG env variable",
args: args{
dynamoDeploymentComponent: &nvidiacomv1alpha1.DynamoComponentDeployment{
Spec: nvidiacomv1alpha1.DynamoComponentDeploymentSpec{
DynamoTag: "graphs.agg:Frontend",
DynamoComponentDeploymentSharedSpec: nvidiacomv1alpha1.DynamoComponentDeploymentSharedSpec{
ServiceName: "Frontend",
Envs: []corev1.EnvVar{
{
Name: "OTHER",
Value: `value`,
},
},
},
},
},
newPort: 8080,
},
want: []corev1.EnvVar{
{
Name: "OTHER",
Value: `value`,
},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := updateDynDeploymentConfig(tt.args.dynamoDeploymentComponent, tt.args.newPort)
if (err != nil) != tt.wantErr {
t.Errorf("updateDynDeploymentConfig() error = %v, wantErr %v", err, tt.wantErr)
return
}
g := gomega.NewGomegaWithT(t)
g.Expect(tt.args.dynamoDeploymentComponent.Spec.Envs).To(gomega.Equal(tt.want))
})
}
}
...@@ -22,12 +22,19 @@ import ( ...@@ -22,12 +22,19 @@ import (
"crypto/sha256" "crypto/sha256"
"encoding/json" "encoding/json"
"fmt" "fmt"
"reflect"
"sort" "sort"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/types"
"k8s.io/client-go/tools/record"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/client/apiutil" "sigs.k8s.io/controller-runtime/pkg/client/apiutil"
"sigs.k8s.io/controller-runtime/pkg/log"
) )
const ( const (
...@@ -35,76 +42,241 @@ const ( ...@@ -35,76 +42,241 @@ const (
NvidiaAnnotationHashKey = "nvidia.com/last-applied-hash" NvidiaAnnotationHashKey = "nvidia.com/last-applied-hash"
) )
type Resource interface { type Reconciler interface {
client.Object client.Client
GetSpec() any GetRecorder() record.EventRecorder
SetSpec(spec any)
} }
func SyncResource[T Resource](ctx context.Context, c client.Client, desired T, createOnly bool) (T, error) { // ResourceGenerator is a function that generates a resource.
// it must return the resource, a boolean indicating if the resource should be deleted, and an error
// if the resource should be deleted, the returned resource must contain the necessary information to delete it (name and namespace)
type ResourceGenerator[T client.Object] func(ctx context.Context) (T, bool, error)
//nolint:nakedret
func SyncResource[T client.Object](ctx context.Context, r Reconciler, parentResource client.Object, generateResource ResourceGenerator[T]) (modified bool, res T, err error) {
logs := log.FromContext(ctx)
resource, toDelete, err := generateResource(ctx)
if err != nil {
return
}
resourceNamespace := resource.GetNamespace()
resourceName := resource.GetName()
resourceType := reflect.TypeOf(resource).Elem().Name()
logs = logs.WithValues("namespace", resourceNamespace, "resourceName", resourceName, "resourceType", resourceType)
// Retrieve the GroupVersionKind (GVK) of the desired object // Retrieve the GroupVersionKind (GVK) of the desired object
gvk, err := apiutil.GVKForObject(desired, c.Scheme()) gvk, err := apiutil.GVKForObject(resource, r.Scheme())
if err != nil { if err != nil {
return desired, fmt.Errorf("failed to get GVK for object: %w", err) logs.Error(err, "Failed to get GVK for object")
return
} }
// Create a new instance of the object // Create a new instance of the object
obj, err := c.Scheme().New(gvk) obj, err := r.Scheme().New(gvk)
if err != nil { if err != nil {
return desired, fmt.Errorf("failed to create a new object for GVK %s: %w", gvk, err) logs.Error(err, "Failed to create a new object for GVK")
return
} }
// Type assertion to ensure the object implements client.Object // Type assertion to ensure the object implements client.Object
current, ok := obj.(T) oldResource, ok := obj.(T)
if !ok { if !ok {
return desired, fmt.Errorf("failed to cast object to the expected type %T", desired) return
} }
namespacedName := types.NamespacedName{
Name: desired.GetName(), err = r.Get(ctx, types.NamespacedName{Name: resourceName, Namespace: resourceNamespace}, oldResource)
Namespace: desired.GetNamespace(), oldResourceIsNotFound := errors.IsNotFound(err)
if err != nil && !oldResourceIsNotFound {
r.GetRecorder().Eventf(parentResource, corev1.EventTypeWarning, fmt.Sprintf("Get%s", resourceType), "Failed to get %s %s: %s", resourceType, resourceNamespace, err)
logs.Error(err, "Failed to get resource.")
return
} }
err = nil
// Retrieve the existing resource if oldResourceIsNotFound {
err = c.Get(ctx, namespacedName, current) if toDelete {
if err != nil { logs.Info("Resource not found. Nothing to do.")
if errors.IsNotFound(err) { return
// If the resource doesn't exist, create it }
if err := c.Create(ctx, desired); err != nil { logs.Info("Resource not found. Creating a new one.")
return desired, fmt.Errorf("failed to create resource: %w", err)
err = ctrl.SetControllerReference(parentResource, resource, r.Scheme())
if err != nil {
logs.Error(err, "Failed to set controller reference.")
r.GetRecorder().Eventf(parentResource, corev1.EventTypeWarning, "SetControllerReference", "Failed to set controller reference for %s %s: %s", resourceType, resourceNamespace, err)
return
}
var hash string
hash, err = GetSpecHash(resource)
if err != nil {
logs.Error(err, "Failed to get spec hash.")
r.GetRecorder().Eventf(parentResource, corev1.EventTypeWarning, "GetSpecHash", "Failed to get spec hash for %s %s: %s", resourceType, resourceNamespace, err)
return
}
updateHashAnnotation(resource, hash)
r.GetRecorder().Eventf(parentResource, corev1.EventTypeNormal, fmt.Sprintf("Create%s", resourceType), "Creating a new %s %s", resourceType, resourceNamespace)
err = r.Create(ctx, resource)
if err != nil {
logs.Error(err, "Failed to create Resource.")
r.GetRecorder().Eventf(parentResource, corev1.EventTypeWarning, fmt.Sprintf("Create%s", resourceType), "Failed to create %s %s: %s", resourceType, resourceNamespace, err)
return
}
logs.Info(fmt.Sprintf("%s created.", resourceType))
r.GetRecorder().Eventf(parentResource, corev1.EventTypeNormal, fmt.Sprintf("Create%s", resourceType), "Created %s %s", resourceType, resourceNamespace)
modified = true
res = resource
} else {
logs.Info(fmt.Sprintf("%s found.", resourceType))
if toDelete {
logs.Info(fmt.Sprintf("%s not found. Deleting the existing one.", resourceType))
err = r.Delete(ctx, oldResource)
if err != nil {
logs.Error(err, fmt.Sprintf("Failed to delete %s.", resourceType))
r.GetRecorder().Eventf(parentResource, corev1.EventTypeWarning, fmt.Sprintf("Delete%s", resourceType), "Failed to delete %s %s: %s", resourceType, resourceNamespace, err)
return
} }
return desired, nil logs.Info(fmt.Sprintf("%s deleted.", resourceType))
r.GetRecorder().Eventf(parentResource, corev1.EventTypeNormal, fmt.Sprintf("Delete%s", resourceType), "Deleted %s %s", resourceType, resourceNamespace)
modified = true
return
}
// Check if the Spec has changed and update if necessary
var newHash *string
newHash, err = IsSpecChanged(oldResource, resource)
if err != nil {
r.GetRecorder().Eventf(parentResource, corev1.EventTypeWarning, fmt.Sprintf("CalculatePatch%s", resourceType), "Failed to calculate patch for %s %s: %s", resourceType, resourceNamespace, err)
return false, resource, fmt.Errorf("failed to check if spec has changed: %w", err)
} }
return desired, fmt.Errorf("failed to get resource: %w", err) if newHash != nil {
// update the spec of the current object with the desired spec
err = CopySpec(resource, oldResource)
if err != nil {
logs.Error(err, fmt.Sprintf("Failed to copy spec for %s.", resourceType))
r.GetRecorder().Eventf(parentResource, corev1.EventTypeWarning, fmt.Sprintf("CopySpec%s", resourceType), "Failed to copy spec for %s %s: %s", resourceType, resourceNamespace, err)
return
}
updateHashAnnotation(oldResource, *newHash)
err = r.Update(ctx, oldResource)
if err != nil {
logs.Error(err, fmt.Sprintf("Failed to update %s.", resourceType))
r.GetRecorder().Eventf(parentResource, corev1.EventTypeWarning, fmt.Sprintf("Update%s", resourceType), "Failed to update %s %s: %s", resourceType, resourceNamespace, err)
return
}
logs.Info(fmt.Sprintf("%s updated.", resourceType))
r.GetRecorder().Eventf(parentResource, corev1.EventTypeNormal, fmt.Sprintf("Update%s", resourceType), "Updated %s %s", resourceType, resourceNamespace)
modified = true
res = oldResource
} else {
logs.Info(fmt.Sprintf("%s spec is the same. Skipping update.", resourceType))
r.GetRecorder().Eventf(parentResource, corev1.EventTypeNormal, fmt.Sprintf("Update%s", resourceType), "Skipping update %s %s", resourceType, resourceNamespace)
res = oldResource
}
}
return
}
// CopySpec copies only the Spec field from source to destination using Unstructured
func CopySpec(source, destination client.Object) error {
// Convert source to unstructured
sourceMap, err := runtime.DefaultUnstructuredConverter.ToUnstructured(source)
if err != nil {
return err
}
sourceUnstructured := &unstructured.Unstructured{Object: sourceMap}
// Convert destination to unstructured
destMap, err := runtime.DefaultUnstructuredConverter.ToUnstructured(destination)
if err != nil {
return err
}
destUnstructured := &unstructured.Unstructured{Object: destMap}
// Extract only the spec from source
sourceSpec, found, err := unstructured.NestedFieldCopy(sourceUnstructured.Object, "spec")
if err != nil {
return err
}
if !found {
return fmt.Errorf("spec not found in source object")
} }
if createOnly { // Set the spec in the destination
return current, nil err = unstructured.SetNestedField(destUnstructured.Object, sourceSpec, "spec")
if err != nil {
return err
} }
// Check if the Spec has changed and update if necessary // Convert back to the original object
if IsSpecChanged(current, desired) { return runtime.DefaultUnstructuredConverter.FromUnstructured(destUnstructured.Object, destination)
// update the spec of the current object with the desired spec }
desired.SetResourceVersion(current.GetResourceVersion())
if err := c.Update(ctx, desired); err != nil { func getSpec(obj client.Object) (any, error) {
return desired, fmt.Errorf("failed to update resource: %w", err) // Convert source to unstructured
sourceMap, err := runtime.DefaultUnstructuredConverter.ToUnstructured(obj)
if err != nil {
return nil, err
}
sourceUnstructured := &unstructured.Unstructured{Object: sourceMap}
// Extract only the spec from source
spec, found, err := unstructured.NestedFieldCopy(sourceUnstructured.Object, "spec")
if err != nil {
return nil, err
}
if !found {
return nil, nil
}
return spec, nil
}
// IsSpecChanged returns the new hash if the spec has changed between the existing one
func IsSpecChanged(current client.Object, desired client.Object) (*string, error) {
hashStr, err := GetSpecHash(desired)
if err != nil {
return nil, err
}
if currentHash, ok := current.GetAnnotations()[NvidiaAnnotationHashKey]; ok {
if currentHash == hashStr {
return nil, nil
} }
} }
return &hashStr, nil
}
// Return the updated object func GetSpecHash(obj client.Object) (string, error) {
return current, nil spec, err := getSpec(obj)
if err != nil {
return "", err
}
return GetResourceHash(spec)
}
func updateHashAnnotation(obj client.Object, hash string) {
annotations := obj.GetAnnotations()
if annotations == nil {
annotations = map[string]string{}
}
annotations[NvidiaAnnotationHashKey] = hash
obj.SetAnnotations(annotations)
} }
// GetResourceHash returns a consistent hash for the given object spec // GetResourceHash returns a consistent hash for the given object spec
func GetResourceHash(obj any) string { func GetResourceHash(obj any) (string, error) {
// Convert obj to a map[string]interface{} // Convert obj to a map[string]interface{}
objMap, err := json.Marshal(obj) objMap, err := json.Marshal(obj)
if err != nil { if err != nil {
panic(err) return "", err
} }
var objData map[string]interface{} var objData map[string]interface{}
if err := json.Unmarshal(objMap, &objData); err != nil { if err := json.Unmarshal(objMap, &objData); err != nil {
panic(err) return "", err
} }
// Sort keys to ensure consistent serialization // Sort keys to ensure consistent serialization
...@@ -113,56 +285,13 @@ func GetResourceHash(obj any) string { ...@@ -113,56 +285,13 @@ func GetResourceHash(obj any) string {
// Serialize to JSON // Serialize to JSON
serialized, err := json.Marshal(sortedObjData) serialized, err := json.Marshal(sortedObjData)
if err != nil { if err != nil {
panic(err) return "", err
} }
// Compute the hash // Compute the hash
hasher := sha256.New() hasher := sha256.New()
hasher.Write(serialized) hasher.Write(serialized)
return fmt.Sprintf("%x", hasher.Sum(nil)) return fmt.Sprintf("%x", hasher.Sum(nil)), nil
}
// IsSpecChanged returns true if the spec has changed between the existing one
// and the new resource spec compared by hash.
func IsSpecChanged(current Resource, desired Resource) bool {
if current == nil && desired != nil {
return true
}
hashStr := GetResourceHash(desired.GetSpec())
foundHashAnnotation := false
currentAnnotations := current.GetAnnotations()
desiredAnnotations := desired.GetAnnotations()
if currentAnnotations == nil {
currentAnnotations = map[string]string{}
}
if desiredAnnotations == nil {
desiredAnnotations = map[string]string{}
}
for annotation, value := range currentAnnotations {
if annotation == NvidiaAnnotationHashKey {
if value != hashStr {
// Update annotation to be added to resource as per new spec and indicate spec update is required
desiredAnnotations[NvidiaAnnotationHashKey] = hashStr
desired.SetAnnotations(desiredAnnotations)
return true
}
foundHashAnnotation = true
break
}
}
if !foundHashAnnotation {
// Update annotation to be added to resource as per new spec and indicate spec update is required
desiredAnnotations[NvidiaAnnotationHashKey] = hashStr
desired.SetAnnotations(desiredAnnotations)
return true
}
return false
} }
// SortKeys recursively sorts the keys of a map to ensure consistent serialization // SortKeys recursively sorts the keys of a map to ensure consistent serialization
......
...@@ -27,7 +27,7 @@ import ( ...@@ -27,7 +27,7 @@ import (
"emperror.dev/errors" "emperror.dev/errors"
apiStoreClient "github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/dynamo/api_store_client" apiStoreClient "github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/dynamo/api_store_client"
compounaiCommon "github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/dynamo/common" "github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/dynamo/common"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/dynamo/schemas" "github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/dynamo/schemas"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/v1alpha1" "github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/v1alpha1"
commonconfig "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/config" commonconfig "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/config"
...@@ -42,11 +42,17 @@ import ( ...@@ -42,11 +42,17 @@ import (
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
) )
const (
	// ComponentTypePlanner is the component_type value that identifies a
	// service as the planner; GenerateDynamoComponentsDeployments compares
	// DynamoConfig.ComponentType against it.
	ComponentTypePlanner = "planner"
	// PlannerServiceAccountName is the service account name set on the
	// ExtraPodSpec of deployments whose component type is the planner.
	PlannerServiceAccountName = "planner-serviceaccount"
)
// ServiceConfig represents the YAML configuration structure for a service // ServiceConfig represents the YAML configuration structure for a service
type DynamoConfig struct { type DynamoConfig struct {
Enabled bool `yaml:"enabled"` Enabled bool `yaml:"enabled"`
Namespace string `yaml:"namespace"` Namespace string `yaml:"namespace"`
Name string `yaml:"name"` Name string `yaml:"name"`
ComponentType string `yaml:"component_type,omitempty"`
} }
type Resources struct { type Resources struct {
...@@ -230,6 +236,7 @@ func GetDynamoGraphConfig(ctx context.Context, dynamoDeployment *v1alpha1.Dynamo ...@@ -230,6 +236,7 @@ func GetDynamoGraphConfig(ctx context.Context, dynamoDeployment *v1alpha1.Dynamo
func GenerateDynamoComponentsDeployments(ctx context.Context, parentDynamoGraphDeployment *v1alpha1.DynamoGraphDeployment, config *DynamoGraphConfig, ingressSpec *v1alpha1.IngressSpec) (map[string]*v1alpha1.DynamoComponentDeployment, error) { func GenerateDynamoComponentsDeployments(ctx context.Context, parentDynamoGraphDeployment *v1alpha1.DynamoGraphDeployment, config *DynamoGraphConfig, ingressSpec *v1alpha1.IngressSpec) (map[string]*v1alpha1.DynamoComponentDeployment, error) {
dynamoServices := make(map[string]string) dynamoServices := make(map[string]string)
deployments := make(map[string]*v1alpha1.DynamoComponentDeployment) deployments := make(map[string]*v1alpha1.DynamoComponentDeployment)
graphDynamoNamespace := ""
for _, service := range config.Services { for _, service := range config.Services {
deployment := &v1alpha1.DynamoComponentDeployment{} deployment := &v1alpha1.DynamoComponentDeployment{}
deployment.Name = fmt.Sprintf("%s-%s", parentDynamoGraphDeployment.Name, strings.ToLower(service.Name)) deployment.Name = fmt.Sprintf("%s-%s", parentDynamoGraphDeployment.Name, strings.ToLower(service.Name))
...@@ -252,6 +259,18 @@ func GenerateDynamoComponentsDeployments(ctx context.Context, parentDynamoGraphD ...@@ -252,6 +259,18 @@ func GenerateDynamoComponentsDeployments(ctx context.Context, parentDynamoGraphD
deployment.Spec.DynamoNamespace = &dynamoNamespace deployment.Spec.DynamoNamespace = &dynamoNamespace
dynamoServices[service.Name] = fmt.Sprintf("%s/%s", service.Config.Dynamo.Name, dynamoNamespace) dynamoServices[service.Name] = fmt.Sprintf("%s/%s", service.Config.Dynamo.Name, dynamoNamespace)
labels[commonconsts.KubeLabelDynamoNamespace] = dynamoNamespace labels[commonconsts.KubeLabelDynamoNamespace] = dynamoNamespace
// we check that all dynamo components are in the same namespace
// this is needed for the planner to work correctly
// this check will be removed when the global planner will be implemented
if graphDynamoNamespace != "" && graphDynamoNamespace != dynamoNamespace {
return nil, fmt.Errorf("different namespaces for the same graph, expected %s, got %s", graphDynamoNamespace, dynamoNamespace)
}
graphDynamoNamespace = dynamoNamespace
if service.Config.Dynamo.ComponentType == ComponentTypePlanner {
deployment.Spec.ExtraPodSpec = &common.ExtraPodSpec{
ServiceAccountName: PlannerServiceAccountName,
}
}
} }
// Check http_exposed independently // Check http_exposed independently
if config.EntryService == service.Name && service.Config.HttpExposed { if config.EntryService == service.Name && service.Config.HttpExposed {
...@@ -260,14 +279,14 @@ func GenerateDynamoComponentsDeployments(ctx context.Context, parentDynamoGraphD ...@@ -260,14 +279,14 @@ func GenerateDynamoComponentsDeployments(ctx context.Context, parentDynamoGraphD
} }
if service.Config.Resources != nil { if service.Config.Resources != nil {
deployment.Spec.Resources = &compounaiCommon.Resources{ deployment.Spec.Resources = &common.Resources{
Requests: &compounaiCommon.ResourceItem{ Requests: &common.ResourceItem{
CPU: service.Config.Resources.CPU, CPU: service.Config.Resources.CPU,
Memory: service.Config.Resources.Memory, Memory: service.Config.Resources.Memory,
GPU: service.Config.Resources.GPU, GPU: service.Config.Resources.GPU,
Custom: service.Config.Resources.Custom, Custom: service.Config.Resources.Custom,
}, },
Limits: &compounaiCommon.ResourceItem{ Limits: &common.ResourceItem{
CPU: service.Config.Resources.CPU, CPU: service.Config.Resources.CPU,
Memory: service.Config.Resources.Memory, Memory: service.Config.Resources.Memory,
GPU: service.Config.Resources.GPU, GPU: service.Config.Resources.GPU,
......
...@@ -507,6 +507,141 @@ func TestGenerateDynamoComponentsDeployments(t *testing.T) { ...@@ -507,6 +507,141 @@ func TestGenerateDynamoComponentsDeployments(t *testing.T) {
}, },
wantErr: true, wantErr: true,
}, },
{
name: "Test GenerateDynamoComponentsDeployments planner",
args: args{
parentDynamoGraphDeployment: &v1alpha1.DynamoGraphDeployment{
ObjectMeta: metav1.ObjectMeta{
Name: "test-dynamographdeployment",
Namespace: "default",
},
Spec: v1alpha1.DynamoGraphDeploymentSpec{
DynamoGraph: "dynamocomponent:ac4e234",
},
},
config: &DynamoGraphConfig{
DynamoTag: "dynamocomponent:MyService1",
Services: []ServiceConfig{
{
Name: "service1",
Config: Config{
Dynamo: &DynamoConfig{
Enabled: true,
Namespace: "default",
Name: "service1",
ComponentType: ComponentTypePlanner,
},
Resources: &Resources{
CPU: "1",
Memory: "1Gi",
GPU: "0",
Custom: map[string]string{},
},
},
},
},
},
ingressSpec: &v1alpha1.IngressSpec{},
},
want: map[string]*v1alpha1.DynamoComponentDeployment{
"service1": {
ObjectMeta: metav1.ObjectMeta{
Name: "test-dynamographdeployment-service1",
Namespace: "default",
Labels: map[string]string{
commonconsts.KubeLabelDynamoComponent: "service1",
commonconsts.KubeLabelDynamoNamespace: "default",
},
},
Spec: v1alpha1.DynamoComponentDeploymentSpec{
DynamoComponent: "dynamocomponent:ac4e234",
DynamoTag: "dynamocomponent:MyService1",
DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
ServiceName: "service1",
DynamoNamespace: &[]string{"default"}[0],
Resources: &compounaiCommon.Resources{
Requests: &compounaiCommon.ResourceItem{
CPU: "1",
Memory: "1Gi",
GPU: "0",
Custom: map[string]string{},
},
Limits: &compounaiCommon.ResourceItem{
CPU: "1",
Memory: "1Gi",
GPU: "0",
Custom: map[string]string{},
},
},
Labels: map[string]string{
commonconsts.KubeLabelDynamoComponent: "service1",
commonconsts.KubeLabelDynamoNamespace: "default",
},
ExtraPodSpec: &compounaiCommon.ExtraPodSpec{
ServiceAccountName: PlannerServiceAccountName,
},
Autoscaling: &v1alpha1.Autoscaling{},
},
},
},
},
wantErr: false,
},
{
name: "Test GenerateDynamoComponentsDeployments dynamo dependency, different namespace",
args: args{
parentDynamoGraphDeployment: &v1alpha1.DynamoGraphDeployment{
ObjectMeta: metav1.ObjectMeta{
Name: "test-dynamographdeployment",
Namespace: "default",
},
Spec: v1alpha1.DynamoGraphDeploymentSpec{
DynamoGraph: "dynamocomponent:ac4e234",
},
},
config: &DynamoGraphConfig{
DynamoTag: "dynamocomponent:MyService2",
EntryService: "service1",
Services: []ServiceConfig{
{
Name: "service1",
Dependencies: []map[string]string{{"service": "service2"}},
Config: Config{
Dynamo: &DynamoConfig{
Enabled: true,
Namespace: "namespace1",
Name: "service1",
},
Resources: &Resources{
CPU: "1",
Memory: "1Gi",
GPU: "0",
Custom: map[string]string{},
},
Autoscaling: &Autoscaling{
MinReplicas: 1,
MaxReplicas: 5,
},
},
},
{
Name: "service2",
Dependencies: []map[string]string{},
Config: Config{
Dynamo: &DynamoConfig{
Enabled: true,
Namespace: "namespace2",
Name: "service2",
},
},
},
},
},
ingressSpec: &v1alpha1.IngressSpec{},
},
want: nil,
wantErr: true,
},
{ {
name: "Test GenerateDynamoComponentsDeployments ingress enabled by default", name: "Test GenerateDynamoComponentsDeployments ingress enabled by default",
args: args{ args: args{
......
...@@ -173,6 +173,11 @@ def main( ...@@ -173,6 +173,11 @@ def main(
if service_name and service_name != service.name: if service_name and service_name != service.name:
service = service.find_dependent_by_name(service_name) service = service.find_dependent_by_name(service_name)
# Set namespace in dynamo_context if service is a dynamo component
if service.is_dynamo_component():
namespace, _ = service.dynamo_address()
dynamo_context["namespace"] = namespace
configure_dynamo_logging(service_name=service_name, worker_id=worker_id) configure_dynamo_logging(service_name=service_name, worker_id=worker_id)
if runner_map: if runner_map:
BentoMLContainer.remote_runner_mapping.set( BentoMLContainer.remote_runner_mapping.set(
......
...@@ -19,6 +19,7 @@ import logging ...@@ -19,6 +19,7 @@ import logging
import os import os
from collections import defaultdict from collections import defaultdict
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Tuple, TypeVar, Union from typing import Any, Dict, List, Optional, Set, Tuple, TypeVar, Union
# WARNING: internal # WARNING: internal
...@@ -34,6 +35,16 @@ T = TypeVar("T", bound=object) ...@@ -34,6 +35,16 @@ T = TypeVar("T", bound=object)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ComponentType(str, Enum):
    """Types of Dynamo components.

    The class mixes in ``str``, so each member is also a string and compares
    equal to its plain value (``ComponentType.PLANNER == "planner"``).
    """

    # Marks a service as the planner component.
    PLANNER = "planner"
    # Future types can be added here like:
    # METRICS = "metrics"
    # MONITOR = "monitor"
    # etc.
class RuntimeLinkedServices: class RuntimeLinkedServices:
""" """
A class to track the linked services in the runtime. A class to track the linked services in the runtime.
...@@ -68,6 +79,9 @@ class DynamoConfig: ...@@ -68,6 +79,9 @@ class DynamoConfig:
name: str | None = None name: str | None = None
namespace: str | None = None namespace: str | None = None
custom_lease: LeaseConfig | None = None custom_lease: LeaseConfig | None = None
component_type: ComponentType | None = (
None # Indicates if this is a meta/system component
)
@dataclass @dataclass
......
...@@ -17,6 +17,7 @@ import logging ...@@ -17,6 +17,7 @@ import logging
import subprocess import subprocess
from pathlib import Path from pathlib import Path
from components.planner_service import Planner
from components.processor import Processor from components.processor import Processor
from components.worker import VllmWorker from components.worker import VllmWorker
from fastapi import FastAPI from fastapi import FastAPI
...@@ -60,6 +61,7 @@ class FrontendConfig(BaseModel): ...@@ -60,6 +61,7 @@ class FrontendConfig(BaseModel):
app=FastAPI(title="LLM Example"), app=FastAPI(title="LLM Example"),
) )
class Frontend: class Frontend:
planner = depends(Planner)
worker = depends(VllmWorker) worker = depends(VllmWorker)
processor = depends(Processor) processor = depends(Processor)
......
...@@ -30,7 +30,7 @@ from tensorboardX import SummaryWriter ...@@ -30,7 +30,7 @@ from tensorboardX import SummaryWriter
from utils.prefill_queue import PrefillQueue from utils.prefill_queue import PrefillQueue
from dynamo.llm import KvMetricsAggregator from dynamo.llm import KvMetricsAggregator
from dynamo.planner import LocalConnector from dynamo.planner import KubernetesConnector, LocalConnector
from dynamo.runtime import DistributedRuntime, dynamo_worker from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
...@@ -53,7 +53,12 @@ class Planner: ...@@ -53,7 +53,12 @@ class Planner:
self.runtime = runtime self.runtime = runtime
self.args = args self.args = args
self.namespace = args.namespace self.namespace = args.namespace
self.connector = LocalConnector(args.namespace, runtime) if args.environment == "local":
self.connector = LocalConnector(args.namespace, runtime)
elif args.environment == "kubernetes":
self.connector = KubernetesConnector(args.namespace)
else:
raise ValueError(f"Invalid environment: {args.environment}")
self._prefill_queue_nats_server = os.getenv( self._prefill_queue_nats_server = os.getenv(
"NATS_SERVER", "nats://localhost:4222" "NATS_SERVER", "nats://localhost:4222"
...@@ -355,7 +360,8 @@ class Planner: ...@@ -355,7 +360,8 @@ class Planner:
await asyncio.sleep(self.args.metric_pulling_interval / 10) await asyncio.sleep(self.args.metric_pulling_interval / 10)
@dynamo_worker() # @dynamo_worker()
# TODO: let's make it such that planner still works via CLI invocation
async def start_planner(runtime: DistributedRuntime, args: argparse.Namespace): async def start_planner(runtime: DistributedRuntime, args: argparse.Namespace):
planner = Planner(runtime, args) planner = Planner(runtime, args)
console = Console() console = Console()
...@@ -465,5 +471,11 @@ if __name__ == "__main__": ...@@ -465,5 +471,11 @@ if __name__ == "__main__":
default=1, default=1,
help="Number of GPUs per prefill engine", help="Number of GPUs per prefill engine",
) )
parser.add_argument(
"--environment",
type=str,
default="local",
help="Environment to run the planner in (local, kubernetes)",
)
args = parser.parse_args() args = parser.parse_args()
asyncio.run(start_planner(args)) asyncio.run(dynamo_worker()(start_planner)(args))
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from pydantic import BaseModel
from components.planner import start_planner # type: ignore[attr-defined]
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.sdk import async_on_start, dynamo_context, dynamo_endpoint, service
from dynamo.sdk.lib.config import ServiceConfig
from dynamo.sdk.lib.image import DYNAMO_IMAGE
logger = logging.getLogger(__name__)
class RequestType(BaseModel):
    """Request body accepted by ``Planner.generate``."""

    # Text payload of the request; the placeholder endpoint does not read it.
    text: str
@service(
    dynamo={
        "enabled": True,
        "namespace": "dynamo",
        "component_type": "planner",
    },
    resources={"cpu": "10", "memory": "20Gi"},
    workers=1,
    image=DYNAMO_IMAGE,
)
class Planner:
    """Dynamo service wrapper that launches the planner from within a graph.

    The constructor assembles an ``argparse.Namespace`` equivalent to what the
    planner CLI would parse, and the start-up hook hands it to
    ``start_planner`` together with the runtime taken from ``dynamo_context``.
    """

    def __init__(self):
        configure_dynamo_logging(service_name="Planner")
        logger.info("Starting planner")

        self.runtime = dynamo_context["runtime"]
        service_config = ServiceConfig.get_instance()

        # The active namespace is read from dynamo_context rather than from
        # the service configuration.
        self.namespace = dynamo_context["namespace"]

        planner_settings = service_config.get("Planner", {})
        self.environment = planner_settings.get("environment", "local")
        self.no_operation = planner_settings.get("no-operation", True)

        # namespace, environment and no_operation come from the context/config
        # above; every other value is fixed here (intended to mirror the
        # planner.py CLI defaults).
        planner_args = {
            "namespace": self.namespace,
            "environment": self.environment,
            "served_model_name": "vllm",
            "no_operation": self.no_operation,
            "log_dir": None,
            "adjustment_interval": 10,
            "metric_pulling_interval": 1,
            "max_gpu_budget": 8,
            "min_endpoint": 1,
            "decode_kv_scale_up_threshold": 0.9,
            "decode_kv_scale_down_threshold": 0.5,
            "prefill_queue_scale_up_threshold": 5,
            "prefill_queue_scale_down_threshold": 0.2,
            "decode_engine_num_gpu": 1,
            "prefill_engine_num_gpu": 1,
        }
        self.args = argparse.Namespace(**planner_args)

    @async_on_start
    async def async_init(self):
        """Wait 60 seconds, then run ``start_planner`` with the prepared args."""
        import asyncio

        # NOTE(review): fixed start-up delay; presumably gives the other
        # components time to come up before planning begins -- confirm.
        startup_delay_seconds = 60
        await asyncio.sleep(startup_delay_seconds)

        logger.info("Calling start_planner")
        await start_planner(self.runtime, self.args)
        logger.info("Planner started")

    @dynamo_endpoint()
    async def generate(self, request: RequestType):
        """Placeholder endpoint so that this component exposes an endpoint."""
        yield "mock endpoint"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment